diff --git a/.deepsource.toml b/.deepsource.toml new file mode 100644 index 0000000000..8499099f65 --- /dev/null +++ b/.deepsource.toml @@ -0,0 +1,12 @@ +version = 1 + +test_patterns = ["**/*_test.go"] + +exclude_patterns = ["scripts/**"] + +[[analyzers]] +name = "go" +enabled = true + + [analyzers.meta] + import_paths = ["github.com/beego/beego/v2"] diff --git a/.github/ISSUE_TEMPLATE b/.github/ISSUE_TEMPLATE index db349198da..a1e5c291a5 100644 --- a/.github/ISSUE_TEMPLATE +++ b/.github/ISSUE_TEMPLATE @@ -1,17 +1,23 @@ +**English Only**. Please use English because others could join the discussion if they got similar issue! + Please answer these questions before submitting your issue. Thanks! -1. What version of Go and beego are you using (`bee version`)? +1. What did you do? +> If possible, provide a recipe for reproducing the error. +> A complete runnable program is good. +> If this is ORM issue, please provide the DB schemas. -2. What operating system and processor architecture are you using (`go env`)? +2. What did you expect to see? +3. What did you see instead? +> please provide log or error information. -3. What did you do? -If possible, provide a recipe for reproducing the error. -A complete runnable program is good. +4. How to reproduce the issue? +> or you can provide a reproduce demo. -4. What did you expect to see? +5. What version of Go and beego are you using (`bee version`)? +6. What operating system and processor architecture are you using (`go env`)? -5. What did you see instead? \ No newline at end of file diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 0000000000..d85c63734e --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,17 @@ +# To get started with Dependabot version updates, you'll need to specify which +# package ecosystems to update and where the package manifests are located. +# Please see the documentation for all configuration options: +# https://help.github.com/github/administering-a-repository/configuration-options-for-dependency-updates + +version: 2 +updates: + - package-ecosystem: "gomod" + directory: "/" + schedule: + interval: "daily" + open-pull-requests-limit: 10 + + - package-ecosystem: "github-actions" + directory: "/" + schedule: + interval: "daily" diff --git a/.github/workflows/follow-template.yml b/.github/workflows/follow-template.yml new file mode 100644 index 0000000000..73e3210dc1 --- /dev/null +++ b/.github/workflows/follow-template.yml @@ -0,0 +1,19 @@ +name: issue-must-follow-template + +on: + schedule: + - cron: "0 * * * *" # pick a cron here, this is every 1h + workflow_dispatch: + +jobs: + build: + runs-on: ubuntu-latest + steps: + - uses: luanpotter/changes-requested@master + with: + github-token: ${{ secrets.GITHUB_TOKEN }} + # these are optional, if you want to configure: + days-until-close: 5 + trigger-label: status/no-follow-template + closing-comment: This issue was closed by the follow-template bot. Please follow the issue template to provide necessary info. + dry-run: false \ No newline at end of file diff --git a/.github/workflows/need-feedback.yml b/.github/workflows/need-feedback.yml new file mode 100644 index 0000000000..7754c9d035 --- /dev/null +++ b/.github/workflows/need-feedback.yml @@ -0,0 +1,21 @@ +name: need-feeback-issues + +on: + schedule: + - cron: "0 * * * *" # pick a cron here, this is every 1h + workflow_dispatch: + +jobs: + build: + runs-on: ubuntu-latest + steps: + - uses: luanpotter/changes-requested@master + with: + github-token: ${{ secrets.GITHUB_TOKEN }} + # these are optional, if you want to configure: + days-until-close: 5 + trigger-label: status/need-feedback + closing-comment: | + This issue was closed by the need-feedback bot. + @${issue.user.login}, please follow the issue template/discussion to provide more details. + dry-run: false \ No newline at end of file diff --git a/.github/workflows/need-translation.yml b/.github/workflows/need-translation.yml new file mode 100644 index 0000000000..df153b2cd8 --- /dev/null +++ b/.github/workflows/need-translation.yml @@ -0,0 +1,19 @@ +name: need-translation-issues + +on: + schedule: + - cron: "0 * * * *" # pick a cron here, this is every 1h + workflow_dispatch: + +jobs: + build: + runs-on: ubuntu-latest + steps: + - uses: luanpotter/changes-requested@master + with: + github-token: ${{ secrets.GITHUB_TOKEN }} + # these are optional, if you want to configure: + days-until-close: 5 + trigger-label: status/need-translation + closing-comment: This issue was closed by the need-translation bot. Please translate your issue to English so it could help others. + dry-run: false \ No newline at end of file diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml new file mode 100644 index 0000000000..dd490e9a1f --- /dev/null +++ b/.github/workflows/stale.yml @@ -0,0 +1,25 @@ +name: Mark stale issues and pull requests + +on: + schedule: + - cron: "30 1 * * *" + +permissions: + contents: read + +jobs: + stale: + + permissions: + issues: write # for actions/stale to close stale issues + pull-requests: write # for actions/stale to close stale PRs + runs-on: ubuntu-latest + + steps: + - uses: actions/stale@v9 + with: + repo-token: ${{ secrets.GITHUB_TOKEN }} + stale-issue-message: 'This issue is inactive for a long time.' + stale-pr-message: 'This PR is inactive for a long time' + stale-issue-label: 'inactive-issue' + stale-pr-label: 'inactive-pr' diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml new file mode 100644 index 0000000000..ede3ed863d --- /dev/null +++ b/.github/workflows/test.yml @@ -0,0 +1,127 @@ +name: Test +on: + push: + branches: + - master + - develop + paths: + - "**/*.go" + - "go.mod" + - "go.sum" + - ".github/workflows/test.yml" + pull_request: + types: [opened, synchronize, reopened] + branches: + - master + - develop + paths: + - "**/*.go" + - "go.mod" + - "go.sum" + - ".github/workflows/test.yml" + +permissions: + contents: read + +jobs: + test: + strategy: + fail-fast: false + matrix: + go-version: ["1.20",1.21,1.22] + runs-on: ubuntu-latest + services: + redis: + image: redis:latest + ports: + - 6379:6379 + memcached: + image: memcached:latest + ports: + - 11211:11211 + ssdb: + image: tsl0922/ssdb + env: + SSDB_PORT: 8888 + ports: + - "8888:8888" + postgres: + image: postgres:latest + env: + POSTGRES_USER: postgres + POSTGRES_PASSWORD: postgres + POSTGRES_DB: orm_test + ports: + - 5432/tcp + options: --health-cmd pg_isready --health-interval 10s --health-timeout 5s --health-retries 5 + + steps: + - name: Set up Go + uses: actions/setup-go@v5 + with: + go-version: ${{ matrix.go-version }} + + - name: Checkout codebase + uses: actions/checkout@v2 + with: + fetch-depth: 0 + + - name: Run etcd + env: + ETCD_VERSION: v3.4.16 + run: | + rm -rf /tmp/etcd-data.tmp + mkdir -p /tmp/etcd-data.tmp + docker rmi gcr.io/etcd-development/etcd:${ETCD_VERSION} || true && \ + docker run -d \ + -p 2379:2379 \ + -p 2380:2380 \ + --mount type=bind,source=/tmp/etcd-data.tmp,destination=/etcd-data \ + --name etcd-gcr-${ETCD_VERSION} \ + gcr.io/etcd-development/etcd:${ETCD_VERSION} \ + /usr/local/bin/etcd \ + --name s1 \ + --data-dir /etcd-data \ + --listen-client-urls http://0.0.0.0:2379 \ + --advertise-client-urls http://0.0.0.0:2379 \ + --listen-peer-urls http://0.0.0.0:2380 \ + --initial-advertise-peer-urls http://0.0.0.0:2380 \ + --initial-cluster s1=http://0.0.0.0:2380 \ + --initial-cluster-token tkn \ + --initial-cluster-state new + docker exec etcd-gcr-${ETCD_VERSION} /bin/sh -c "ETCDCTL_API=3 /usr/local/bin/etcdctl put current.float 1.23" + docker exec etcd-gcr-${ETCD_VERSION} /bin/sh -c "ETCDCTL_API=3 /usr/local/bin/etcdctl put current.bool true" + docker exec etcd-gcr-${ETCD_VERSION} /bin/sh -c "ETCDCTL_API=3 /usr/local/bin/etcdctl put current.int 11" + docker exec etcd-gcr-${ETCD_VERSION} /bin/sh -c "ETCDCTL_API=3 /usr/local/bin/etcdctl put current.string hello" + docker exec etcd-gcr-${ETCD_VERSION} /bin/sh -c "ETCDCTL_API=3 /usr/local/bin/etcdctl put current.serialize.name test" + docker exec etcd-gcr-${ETCD_VERSION} /bin/sh -c "ETCDCTL_API=3 /usr/local/bin/etcdctl put sub.sub.key1 sub.sub.key" + + - name: Run ORM tests on sqlite3 + env: + GOPATH: /home/runner/go + ORM_DRIVER: sqlite3 + ORM_SOURCE: /tmp/sqlite3/orm_test.db + run: | + mkdir -p /tmp/sqlite3 && touch /tmp/sqlite3/orm_test.db + go test -coverprofile=coverage_sqlite3.txt -covermode=atomic $(go list ./... | grep client/orm) + + - name: Run ORM tests on postgres + env: + GOPATH: /home/runner/go + ORM_DRIVER: postgres + ORM_SOURCE: host=localhost port=${{ job.services.postgres.ports[5432] }} user=postgres password=postgres dbname=orm_test sslmode=disable + run: | + go test -coverprofile=coverage_postgres.txt -covermode=atomic $(go list ./... | grep client/orm) + + - name: Run tests on mysql + env: + GOPATH: /home/runner/go + ORM_DRIVER: mysql + ORM_SOURCE: root:root@/orm_test?charset=utf8 + run: | + sudo systemctl start mysql + mysql -u root -proot -e 'create database orm_test;' + go test -coverprofile=coverage.txt -covermode=atomic ./... + + - name: Upload codecov + run: bash <(curl -s https://codecov.io/bash) diff --git a/.gitignore b/.gitignore index e1b6529101..9a2af3cfca 100644 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,14 @@ *.swp *.swo beego.iml + +_beeTmp/ +_beeTmp2/ +pkg/_beeTmp/ +pkg/_beeTmp2/ +test/tmp/ +core/config/env/pkg/ + +my save path/ + +profile.out diff --git a/.travis.yml b/.travis.yml deleted file mode 100644 index 1bb121a21f..0000000000 --- a/.travis.yml +++ /dev/null @@ -1,74 +0,0 @@ -language: go - -go: - - "1.11.x" -services: - - redis-server - - mysql - - postgresql - - memcached -env: - global: - - GO_REPO_FULLNAME="github.com/astaxie/beego" - matrix: - - ORM_DRIVER=sqlite3 ORM_SOURCE=$TRAVIS_BUILD_DIR/orm_test.db - - ORM_DRIVER=postgres ORM_SOURCE="user=postgres dbname=orm_test sslmode=disable" -before_install: - # link the local repo with ${GOPATH}/src// - - GO_REPO_NAMESPACE=${GO_REPO_FULLNAME%/*} - # relies on GOPATH to contain only one directory... - - mkdir -p ${GOPATH}/src/${GO_REPO_NAMESPACE} - - ln -sv ${TRAVIS_BUILD_DIR} ${GOPATH}/src/${GO_REPO_FULLNAME} - - cd ${GOPATH}/src/${GO_REPO_FULLNAME} - # get and build ssdb - - git clone git://github.com/ideawu/ssdb.git - - cd ssdb - - make - - cd .. -install: - - go get github.com/lib/pq - - go get github.com/go-sql-driver/mysql - - go get github.com/mattn/go-sqlite3 - - go get github.com/bradfitz/gomemcache/memcache - - go get github.com/gomodule/redigo/redis - - go get github.com/beego/x2j - - go get github.com/couchbase/go-couchbase - - go get github.com/beego/goyaml2 - - go get gopkg.in/yaml.v2 - - go get github.com/belogik/goes - - go get github.com/siddontang/ledisdb/config - - go get github.com/siddontang/ledisdb/ledis - - go get github.com/ssdb/gossdb/ssdb - - go get github.com/cloudflare/golz4 - - go get github.com/gogo/protobuf/proto - - go get github.com/Knetic/govaluate - - go get github.com/casbin/casbin - - go get github.com/elazarl/go-bindata-assetfs - - go get github.com/OwnLocal/goes - - go get github.com/shiena/ansicolor - - go get -u honnef.co/go/tools/cmd/staticcheck - - go get -u github.com/mdempsky/unconvert - - go get -u github.com/gordonklaus/ineffassign - - go get -u github.com/golang/lint/golint - - go get -u github.com/go-redis/redis -before_script: - - psql --version - - sh -c "if [ '$ORM_DRIVER' = 'postgres' ]; then psql -c 'create database orm_test;' -U postgres; fi" - - sh -c "if [ '$ORM_DRIVER' = 'mysql' ]; then mysql -u root -e 'create database orm_test;'; fi" - - sh -c "if [ '$ORM_DRIVER' = 'sqlite' ]; then touch $TRAVIS_BUILD_DIR/orm_test.db; fi" - - sh -c "go get github.com/golang/lint/golint; golint ./...;" - - sh -c "go list ./... | grep -v vendor | xargs go vet -v" - - mkdir -p res/var - - ./ssdb/ssdb-server ./ssdb/ssdb.conf -d -after_script: - - killall -w ssdb-server - - rm -rf ./res/var/* -script: - - go test -v ./... - - staticcheck -show-ignored -checks "-ST1017,-U1000,-ST1005,-S1034,-S1012,-SA4006,-SA6005,-SA1019,-SA1024" - - unconvert $(go list ./... | grep -v /vendor/) - - ineffassign . - - find . ! \( -path './vendor' -prune \) -type f -name '*.go' -print0 | xargs -0 gofmt -l -s - - golint ./... -addons: - postgresql: "9.6" diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000000..12d69b38d7 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,199 @@ +# Change Log is deprecated + +Please refer to https://github.com/beego/beego/releases to find the CHANGELOG + +# v2.1.2 +- [refactor: CONTRIBUTING.md file grammatical improvements](https://github.com/beego/beego/issues/5411) +- [fix: refactor Count method](https://github.com/beego/beego/pull/5300) +- [support db_type in ddl ](https://github.com/beego/beego/pull/5404) +- [orm: PostgreSQL change auto to bigserial](https://github.com/beego/beego/pull/5415) + +# v2.1.1 +- [httplib: fix unstable unit test which use the httplib.org](https://github.com/beego/beego/pull/5232) +- [rft: remove adapter package](https://github.com/beego/beego/pull/5239) +- [feat: add write-delete cache mode](https://github.com/beego/beego/pull/5242) +- [feat: add write-double-delete cache mode](https://github.com/beego/beego/pull/5243) +- [fix 5255: Check the rows.Err() if rows.Next() is false](https://github.com/beego/beego/pull/5256) +- [orm: missing handling %COL% placeholder](https://github.com/beego/beego/pull/5257) +- [fix: use of ioutil package](https://github.com/beego/beego/pull/5261) +- [cache/redis: support skipEmptyPrefix option ](https://github.com/beego/beego/pull/5264) +- [fix: refactor InsertValue method](https://github.com/beego/beego/pull/5267) +- [fix: modify InsertOrUpdate method, Remove the isMulti variable and its associated code](https://github.com/beego/beego/pull/5269) +- [refactor cache/redis: Use redisConfig to receive incoming JSON (previously using a map)](https://github.com/beego/beego/pull/5268) +- [fix: refactor DeleteSQL method](https://github.com/beego/beego/pull/5271) +- [fix: refactor UpdateSQL method](https://github.com/beego/beego/pull/5274) +- [fix: refactor UpdateBatch method](https://github.com/beego/beego/pull/5295) +- [fix: refactor InsertOrUpdate method](https://github.com/beego/beego/pull/5296) +- [fix: refactor ReadBatch method](https://github.com/beego/beego/pull/5298) +- [fix: refactor ReadValues method](https://github.com/beego/beego/pull/5303) + +## ORM refactoring +- [introducing internal/models pkg](https://github.com/beego/beego/pull/5238) + +# v2.1.0 +- [unified gopkg.in/yaml version to v2](https://github.com/beego/beego/pull/5169) +- [add non-block write log in asynchronous mode](https://github.com/beego/beego/pull/5150) +- [Fix 5126: support bloom filter cache](https://github.com/beego/beego/pull/5126) +- [Fix 5117: support write though cache](https://github.com/beego/beego/pull/5117) +- [add read through for cache module](https://github.com/beego/beego/pull/5116) +- [add singleflight cache for cache module](https://github.com/beego/beego/pull/5119) +- [Fix 5129: must set formatter after init the logger](https://github.com/beego/beego/pull/5130) +- [Fix 5079: only log msg when the channel is not closed](https://github.com/beego/beego/pull/5132) +- [Fix 4435: Controller SaveToFile remove all temp file](https://github.com/beego/beego/pull/5138) +- [Fix 5079: Split signalChan into flushChan and closeChan](https://github.com/beego/beego/pull/5139) +- [Fix 5172: protect field access with lock to avoid possible data race](https://github.com/beego/beego/pull/5210) +- [cache: fix typo and optimize the naming]() +- [Fix 5176: beegoAppConfig String and Strings function has bug](https://github.com/beego/beego/pull/5211) + +# v2.1.0 +- [unified gopkg.in/yaml version to v2](https://github.com/beego/beego/pull/5169) +- [add non-block write log in asynchronous mode](https://github.com/beego/beego/pull/5150) +- [Fix 5126: support bloom filter cache](https://github.com/beego/beego/pull/5126) +- [Fix 5117: support write though cache](https://github.com/beego/beego/pull/5117) +- [add read through for cache module](https://github.com/beego/beego/pull/5116) +- [add singleflight cache for cache module](https://github.com/beego/beego/pull/5119) +- [Fix 5129: must set formatter after init the logger](https://github.com/beego/beego/pull/5130) +- [Fix 5079: only log msg when the channel is not closed](https://github.com/beego/beego/pull/5132) +- [Fix 4435: Controller SaveToFile remove all temp file](https://github.com/beego/beego/pull/5138) +- [Fix 5079: Split signalChan into flushChan and closeChan](https://github.com/beego/beego/pull/5139) +- [Fix 5172: protect field access with lock to avoid possible data race](https://github.com/beego/beego/pull/5210) +- [cache: fix typo and optimize the naming]() + +# v2.0.7 +- [Upgrade github.com/go-kit/kit, CVE-2022-24450](https://github.com/beego/beego/pull/5121) +# v2.0.6 +- [fix: revise the body wrapper to handle empty body case](https://github.com/beego/beego/pull/5102) + +# v2.0.5 + +Note: now we force the web admin service serving HTTP only. + +- [Fix 4984: random expire cache](https://github.com/beego/beego/pull/4984) +- [Fix 4907: make admin serve HTTP only](https://github.com/beego/beego/pull/5005) +- [Feat 4999: add get all tasks function](https://github.com/beego/beego/pull/4999) +- [Fix 5012: fix some bug, pass []any as any in variadic function](https://github.com/beego/beego/pull/5012) +- [Fix 5022: Miss assigning listener to graceful Server](https://github.com/beego/beego/pull/5028) +- [Fix 4955: Make commands and Docker compose for ORM unit tests](https://github.com/beego/beego/pull/5031) + +# v2.0.4 + +Note: now we force the web admin service serving HTTP only. + +- [Fix issue 4961, `leafInfo.match()` use `path.join()` to deal with `wildcardValues`, which may lead to cross directory risk ](https://github.com/beego/beego/pull/4964) +- [Fix 4975: graceful server listen the specific address](https://github.com/beego/beego/pull/4979) +- [Fix 4976: make admin serve HTTP only](https://github.com/beego/beego/pull/4980) + +# v2.0.3 +- [upgrade redisgo to v1.8.8](https://github.com/beego/beego/pull/4872) +- [fix prometheus CVE-2022-21698](https://github.com/beego/beego/pull/4878) +- [upgrade to Go 1.18](https://github.com/beego/beego/pull/4896) +- [make `PatternLogFormatter` handling the arguments](https://github.com/beego/beego/pull/4914/files) +- [Add httplib OpenTelemetry Filter](https://github.com/beego/beego/pull/4888, https://github.com/beego/beego/pull/4915) +- [Add orm OpenTelemetry Filter](https://github.com/beego/beego/issues/4944) +- [Support NewBeegoRequestWithCtx in httplib](https://github.com/beego/beego/pull/4895) +- [Support lifecycle callback](https://github.com/beego/beego/pull/4918) +- [Append column comments to create table sentence when using postgres](https://github.com/beego/beego/pull/4940) +- [logs: multiFileLogWriter uses incorrect formatter](https://github.com/beego/beego/pull/4943) +- [fix issue 4946 CVE-2022-31259](https://github.com/beego/beego/pull/4954) + +# v2.0.2 +See v2.0.2-beta.1 +- [fix bug: etcd should use etcd as adapter name](https://github.com/beego/beego/pull/4845) +# v2.0.2-beta.1 +- Add a custom option for whether to escape HTML special characters when processing http request parameters. [4701](https://github.com/beego/beego/pull/4701) +- Always set the response status in the CustomAbort function. [4686](https://github.com/beego/beego/pull/4686) +- Add template functions eq,lt to support uint and int compare. [4607](https://github.com/beego/beego/pull/4607) +- Migrate tests to GitHub Actions. [4663](https://github.com/beego/beego/issues/4663) +- Add http client and option func. [4455](https://github.com/beego/beego/issues/4455) +- Add: Convenient way to generate mock object [4620](https://github.com/beego/beego/issues/4620) +- Infra: use dependabot to update dependencies. [4623](https://github.com/beego/beego/pull/4623) +- Lint: use golangci-lint. [4619](https://github.com/beego/beego/pull/4619) +- Chore: format code. [4615](https://github.com/beego/beego/pull/4615) +- Test on Go v1.15.x & v1.16.x. [4614](https://github.com/beego/beego/pull/4614) +- Env: non-empty GOBIN & GOPATH. [4613](https://github.com/beego/beego/pull/4613) +- Chore: update dependencies. [4611](https://github.com/beego/beego/pull/4611) +- Update orm_test.go/TestInsertOrUpdate with table-driven. [4609](https://github.com/beego/beego/pull/4609) +- Add: Resp() method for web.Controller. [4588](https://github.com/beego/beego/pull/4588) +- Web mock and test support. [4565](https://github.com/beego/beego/pull/4565) [4574](https://github.com/beego/beego/pull/4574) +- Error codes definition of cache module. [4493](https://github.com/beego/beego/pull/4493) +- Remove generateCommentRoute http hook. Using `bee generate routers` commands instead.[4486](https://github.com/beego/beego/pull/4486) [bee PR 762](https://github.com/beego/bee/pull/762) +- Fix: /abc.html/aaa match /abc/aaa. [4459](https://github.com/beego/beego/pull/4459) +- ORM mock. [4407](https://github.com/beego/beego/pull/4407) +- Add sonar check and ignore test. [4432](https://github.com/beego/beego/pull/4432) [4433](https://github.com/beego/beego/pull/4433) +- Update changlog.yml to check every PR to develop branch.[4427](https://github.com/beego/beego/pull/4427) +- Fix 4396: Add context.param module into adapter. [4398](https://github.com/beego/beego/pull/4398) +- Support `RollbackUnlessCommit` API. [4542](https://github.com/beego/beego/pull/4542) +- Fix 4503 and 4504: Add `when` to `Write([]byte)` method and add `prefix` to `writeMsg`. [4507](https://github.com/beego/beego/pull/4507) +- Fix 4480: log format incorrect. [4482](https://github.com/beego/beego/pull/4482) +- Remove `duration` from prometheus labels. [4391](https://github.com/beego/beego/pull/4391) +- Fix `unknown escape sequence` in generated code. [4385](https://github.com/beego/beego/pull/4385) +- Fix 4590: Forget to check URL when FilterChain invoke `next()`. [4593](https://github.com/beego/beego/pull/4593) +- Fix 4727: CSS when request URI is invalid. [4729](https://github.com/beego/beego/pull/4729) +- Using fixed name `commentRouter.go` as generated file name. [4385](https://github.com/beego/beego/pull/4385) +- Fix 4383: ORM Adapter produces panic when using orm.RegisterModelWithPrefix. [4386](https://github.com/beego/beego/pull/4386) +- Support 4144: Add new api for order by for supporting multiple way to query [4294](https://github.com/beego/beego/pull/4294) +- Support session Filter chain. [4404](https://github.com/beego/beego/pull/4404) +- Feature issue #4402 finish router get example. [4416](https://github.com/beego/beego/pull/4416) +- Implement context.Context support and deprecate `QueryM2MWithCtx` and `QueryTableWithCtx` [4424](https://github.com/beego/beego/pull/4424) +- Finish timeout option for tasks #4441 [4441](https://github.com/beego/beego/pull/4441) +- Error Module brief design & using httplib module to validate this design. [4453](https://github.com/beego/beego/pull/4453) +- Fix 4444: panic when 404 not found. [4446](https://github.com/beego/beego/pull/4446) +- Fix 4435: fix panic when controller dir not found. [4452](https://github.com/beego/beego/pull/4452) +- Hotfix:reflect.ValueOf(nil) in getFlatParams [4716](https://github.com/beego/beego/issues/4716) +- Fix 4456: Fix router method expression [4456](https://github.com/beego/beego/pull/4456) +- Remove some `go get` lines in `.travis.yml` file [4469](https://github.com/beego/beego/pull/4469) +- Fix 4451: support QueryExecutor interface. [4461](https://github.com/beego/beego/pull/4461) +- Add some testing scripts [4461](https://github.com/beego/beego/pull/4461) +- Refactor httplib: Move debug code to a filter [4440](https://github.com/beego/beego/issues/4440) +- fix: code quality issues [4513](https://github.com/beego/beego/pull/4513) +- Optimize maligned structs to reduce memory foot-print [4525](https://github.com/beego/beego/pull/4525) +- Feat: add token bucket ratelimit filter [4508](https://github.com/beego/beego/pull/4508) +- Improve: Avoid ignoring mistakes that need attention [4548](https://github.com/beego/beego/pull/4548) +- Integration: DeepSource [4560](https://github.com/beego/beego/pull/4560) +- Integration: Remove unnecessary function call [4577](https://github.com/beego/beego/pull/4577) +- Feature issue #4402 finish router get example. [4416](https://github.com/beego/beego/pull/4416) +- Proposal: Add Bind() method for `web.Controller` [4491](https://github.com/beego/beego/issues/4579) +- Optimize AddAutoPrefix: only register one router in case-insensitive mode. [4582](https://github.com/beego/beego/pull/4582) +- Init exceptMethod by using reflection. [4583](https://github.com/beego/beego/pull/4583) +- Deprecated BeeMap and replace all usage with `sync.map` [4616](https://github.com/beego/beego/pull/4616) +- TaskManager support graceful shutdown [4635](https://github.com/beego/beego/pull/4635) +- Add comments to `web.Config`, rename `RouterXXX` to `CtrlXXX`, define `HandleFunc` [4714](https://github.com/beego/beego/pull/4714) +- Refactor: Move `BindXXX` and `XXXResp` methods to `context.Context`. [4718](https://github.com/beego/beego/pull/4718) +- Fix 4728: Print wrong file name. [4737](https://github.com/beego/beego/pull/4737) +- fix bug:reflect.ValueOf(nil) in getFlatParams [4715](https://github.com/beego/beego/pull/4715) +- Fix 4736: set a fixed value "/" to the "Path" of "_xsrf" cookie. [4736](https://github.com/beego/beego/issues/4735) [4739](https://github.com/beego/beego/issues/4739) +- Fix 4734: do not reset id in Delete function. [4738](https://github.com/beego/beego/pull/4738) [4742](https://github.com/beego/beego/pull/4742) +- Fix 4699: Remove Remove goyaml2 dependency. [4755](https://github.com/beego/beego/pull/4755) +- Fix 4698: Prompt error when config format is incorrect. [4757](https://github.com/beego/beego/pull/4757) +- Fix 4674: Tx Orm missing debug log [4756](https://github.com/beego/beego/pull/4756) +- Fix 4759: fix numeric notation of permissions [4759](https://github.com/beego/beego/pull/4759) +- set default rate and capacity for ratelimit filter [4796](https://github.com/beego/beego/pull/4796) +- Fix 4782: must set status before rendering error page [4797](https://github.com/beego/beego/pull/4797) +- Fix 4791: delay to format parameter in log module [4804](https://github.com/beego/beego/pull/4804) + +## Fix Sonar + +- [4677](https://github.com/beego/beego/pull/4677) +- [4624](https://github.com/beego/beego/pull/4624) +- [4608](https://github.com/beego/beego/pull/4608) +- [4473](https://github.com/beego/beego/pull/4473) +- [4474](https://github.com/beego/beego/pull/4474) +- [4479](https://github.com/beego/beego/pull/4479) +- [4639](https://github.com/beego/beego/pull/4639) +- [4668](https://github.com/beego/beego/pull/4668) + +## Fix lint and format code + +- [4644](https://github.com/beego/beego/pull/4644) +- [4645](https://github.com/beego/beego/pull/4645) +- [4646](https://github.com/beego/beego/pull/4646) +- [4647](https://github.com/beego/beego/pull/4647) +- [4648](https://github.com/beego/beego/pull/4648) +- [4649](https://github.com/beego/beego/pull/4649) +- [4651](https://github.com/beego/beego/pull/4651) +- [4652](https://github.com/beego/beego/pull/4652) +- [4653](https://github.com/beego/beego/pull/4653) +- [4654](https://github.com/beego/beego/pull/4654) +- [4655](https://github.com/beego/beego/pull/4655) +- [4656](https://github.com/beego/beego/pull/4656) +- [4660](https://github.com/beego/beego/pull/4660) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 9d51161652..3974985038 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -1,37 +1,75 @@ # Contributing to beego -beego is an open source project. +Beego is an open-source project. -It is the work of hundreds of contributors. We appreciate your help! +It is the work of hundreds of contributors. And you could be among them, so we appreciate your help! -Here are instructions to get you started. They are probably not perfect, -please let us know if anything feels wrong or incomplete. +Here are instructions to get you started. They are probably not perfect, so please let us know if anything feels wrong or +incomplete. + +## Prepare environment + +Firstly, you need to install some tools. Execute the commands below **outside** the project. Otherwise, this action will modify the go.mod file. + +```shell script +go get -u golang.org/x/tools/cmd/goimports + +go get -u github.com/gordonklaus/ineffassign +``` + +Put the lines below in your pre-commit git hook script: + +```shell script +goimports -w -format-only ./ + +ineffassign . + +staticcheck -show-ignored -checks "-ST1017,-U1000,-ST1005,-S1034,-S1012,-SA4006,-SA6005,-SA1019,-SA1024" ./ +``` + +## Prepare middleware + +Beego uses many middlewares, including MySQL, Redis, SSDB amongs't others. + +We provide a docker-compose file to start all middlewares. + +You can run the following command to start all middlewares: + +```shell script +docker-compose -f scripts/test_docker_compose.yaml up -d +``` + +Unit tests read addresses from environmental variables, you can set them up as shown in the example below: + +```shell script +export ORM_DRIVER=mysql +export ORM_SOURCE="beego:test@tcp(192.168.0.105:13306)/orm_test?charset=utf8" +export MEMCACHE_ADDR="192.168.0.105:11211" +export REDIS_ADDR="192.168.0.105:6379" +export SSDB_ADDR="192.168.0.105:8888" +``` ## Contribution guidelines ### Pull requests -First of all. beego follow the gitflow. So please send you pull request -to **develop** branch. We will close the pull request to master branch. +Beego follows the gitflow. And as such, please submit your pull request to the **develop** branch. We will close the pull request by merging it into the master branch. + +**NOTE:** Don't forget to update the `CHANGELOG.md` file by adding the changes made under the **developing** section. +We'll release them in the next Beego version. -We are always happy to receive pull requests, and do our best to -review them as fast as possible. Not sure if that typo is worth a pull -request? Do it! We will appreciate it. +We are always happy to receive pull requests, and do our best to review them as fast as possible. Not sure if that typo is worth a pull request? Just do it! We will appreciate it. -If your pull request is not accepted on the first try, don't be -discouraged! Sometimes we can make a mistake, please do more explaining -for us. We will appreciate it. +Don't forget to rebase your commits! -We're trying very hard to keep beego simple and fast. We don't want it -to do everything for everybody. This means that we might decide against -incorporating a new feature. But we will give you some advice on how to -do it in other way. +If your pull request is rejected, dont be discouraged. Sometimes we make mistakes. You can provide us with more context by explaining your issue as clearly as possible. + +In our pursuit of maintaining Beego's simplicity and speed, we might not accept some feature requests. We don't want it to do everything for everybody. For this reason, we might decide against incorporating a new feature. However, we will provide guidance on achieving the same thing using a different approach ### Create issues -Any significant improvement should be documented as [a GitHub -issue](https://github.com/astaxie/beego/issues) before anybody -starts working on it. +Any significant improvement should be documented as [a GitHub issue](https://github.com/beego/beego/v2/issues) before +anybody starts working on it. Also when filing an issue, make sure to answer these five questions: @@ -43,10 +81,6 @@ Also when filing an issue, make sure to answer these five questions: ### but check existing issues and docs first! -Please take a moment to check that an issue doesn't already exist -documenting your bug report or improvement proposal. If it does, it -never hurts to add a quick "+1" or "I have this problem too". This will -help prioritize the most common problems and requests. - -Also if you don't know how to use it. please make sure you have read though -the docs in http://beego.me/docs \ No newline at end of file +Take a moment to check that an issue documenting your bug report or improvement proposal doesn't already exist. +If it does, it doesn't hurts to add a quick "+1" or "I have this problem too". This will help prioritize the most common +problems and requests. diff --git a/ERROR_SPECIFICATION.md b/ERROR_SPECIFICATION.md new file mode 100644 index 0000000000..68a04bd11a --- /dev/null +++ b/ERROR_SPECIFICATION.md @@ -0,0 +1,5 @@ +# Error Module + +## Module code +- httplib 1 +- cache 2 \ No newline at end of file diff --git a/LICENSE b/LICENSE index 5dbd424355..b947dac316 100644 --- a/LICENSE +++ b/LICENSE @@ -1,4 +1,4 @@ -Copyright 2014 astaxie +Copyright 2014 Beego Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -10,4 +10,4 @@ Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and -limitations under the License. \ No newline at end of file +limitations under the License. diff --git a/Makefile b/Makefile new file mode 100644 index 0000000000..e994ec4cd7 --- /dev/null +++ b/Makefile @@ -0,0 +1,68 @@ +# Copyright 2020 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +##@ General + +# The help target prints out all targets with their descriptions organized +# beneath their categories. The categories are represented by '##@' and the +# target descriptions by '##'. The awk commands is responsible for reading the +# entire set of makefiles included in this invocation, looking for lines of the +# file as xyz: ## something, and then pretty-format the target and help. Then, +# if there's a line with ##@ something, that gets pretty-printed as a category. +# More info on the usage of ANSI control characters for terminal formatting: +# https://en.wikipedia.org/wiki/ANSI_escape_code#SGR_parameters +# More info on the awk command: +# http://linuxcommand.org/lc3_adv_awk.php + +help: ## Display this help. + @awk 'BEGIN {FS = ":.*##"; printf "\nUsage:\n make \033[36m\033[0m\n"} /^[a-zA-Z_0-9-]+:.*?##/ { printf " \033[36m%-15s\033[0m %s\n", $$1, $$2 } /^##@/ { printf "\n\033[1m%s\033[0m\n", substr($$0, 5) } ' $(MAKEFILE_LIST) + + +##@ Test + +test-orm-mysql5: ## Run ORM unit tests on mysql5. + docker-compose -f scripts/orm_docker_compose.yaml up -d + export ORM_DRIVER=mysql + export ORM_SOURCE="beego:test@tcp(localhost:13306)/orm_test?charset=utf8" + go test -v github.com/beego/beego/v2/client/orm + docker-compose -f scripts/orm_docker_compose.yaml down + +test-orm-mysql8: ## Run ORM unit tests on mysql8. + docker-compose -f scripts/orm_docker_compose.yaml up -d + export ORM_DRIVER=mysql + export ORM_SOURCE="beego:test@tcp(localhost:23306)/orm_test?charset=utf8" + go test -v github.com/beego/beego/v2/client/orm + docker-compose -f scripts/orm_docker_compose.yaml down + +test-orm-pgsql: ## Run ORM unit tests on postgresql. + docker-compose -f scripts/orm_docker_compose.yaml up -d + export ORM_DRIVER=postgres + export ORM_SOURCE="user=postgres password=postgres dbname=orm_test sslmode=disable" + go test -v github.com/beego/beego/v2/client/orm + docker-compose -f scripts/orm_docker_compose.yaml down + +test-orm-tidb: ## Run ORM unit tests on tidb. + docker-compose -f scripts/orm_docker_compose.yaml up -d + export ORM_DRIVER=tidb + export ORM_SOURCE="memory://test/test" + go test -v github.com/beego/beego/v2/client/orm + docker-compose -f scripts/orm_docker_compose.yaml down + +.PHONY: test-orm-all +test-orm-all: test-orm-mysql5 test-orm-mysql8 test-orm-pgsql test-orm-tidb + +.PHONY: fmt +fmt: + goimports -local "github.com/beego/beego" -w . \ No newline at end of file diff --git a/README.md b/README.md index 5063645c4c..10d37d9848 100644 --- a/README.md +++ b/README.md @@ -1,29 +1,49 @@ -# Beego [![Build Status](https://travis-ci.org/astaxie/beego.svg?branch=master)](https://travis-ci.org/astaxie/beego) [![GoDoc](http://godoc.org/github.com/astaxie/beego?status.svg)](http://godoc.org/github.com/astaxie/beego) [![Foundation](https://img.shields.io/badge/Golang-Foundation-green.svg)](http://golangfoundation.org) [![Go Report Card](https://goreportcard.com/badge/github.com/astaxie/beego)](https://goreportcard.com/report/github.com/astaxie/beego) +# Beego [![Test](https://github.com/beego/beego/actions/workflows/test.yml/badge.svg?branch=develop)](https://github.com/beego/beego/actions/workflows/test.yml) [![Go Report Card](https://goreportcard.com/badge/github.com/beego/beego)](https://goreportcard.com/report/github.com/beego/beego) [![Go Reference](https://pkg.go.dev/badge/github.com/beego/beego/v2.svg)](https://pkg.go.dev/github.com/beego/beego/v2) +Beego is used for rapid development of enterprise application in Go, including RESTful APIs, web apps and backend services. -beego is used for rapid development of RESTful APIs, web apps and backend services in Go. It is inspired by Tornado, Sinatra and Flask. beego has some Go-specific features such as interfaces and struct embedding. - Response time ranking: [web-frameworks](https://github.com/the-benchmarker/web-frameworks). +## Quick Start +- [New Doc Website - unavailable](https://beego.gocn.vip) +- [New Doc Website Backup @flycash](https://doc.meoying.com/en-US/beego/developing/) +- [New Doc Website source code](https://github.com/beego/beego-doc) +- [Old Doc - github](https://github.com/beego/beedoc) +- [Example](https://github.com/beego/beego-example) -###### More info at [beego.me](http://beego.me). +> Kindly remind that sometimes the HTTPS certificate is expired, you may get some NOT SECURE warning -## Quick Start +### Web Application + +#### Create `hello` directory, cd `hello` directory + + mkdir hello + cd hello + +#### Init module + + go mod init #### Download and install - go get github.com/astaxie/beego + go get github.com/beego/beego/v2@latest #### Create file `hello.go` + ```go package main -import "github.com/astaxie/beego" +import "github.com/beego/beego/v2/server/web" -func main(){ - beego.Run() +func main() { + web.Run() } ``` + +#### Download required dependencies + + go mod tidy + #### Build and run go build hello.go @@ -33,31 +53,37 @@ func main(){ Congratulations! You've just built your first **beego** app. -###### Please see [Documentation](http://beego.me/docs) for more. - ## Features * RESTful support -* MVC architecture +* [MVC architecture](https://github.com/beego/beedoc/tree/master/en-US/mvc) * Modularity -* Auto API documents -* Annotation router -* Namespace -* Powerful development tools +* [Auto API documents](https://github.com/beego/beedoc/blob/master/en-US/advantage/docs.md) +* [Annotation router](https://github.com/beego/beedoc/blob/master/en-US/mvc/controller/router.md) +* [Namespace](https://github.com/beego/beedoc/blob/master/en-US/mvc/controller/router.md#namespace) +* [Powerful development tools](https://github.com/beego/bee) * Full stack for Web & API -## Documentation +## Modules -* [English](http://beego.me/docs/intro/) -* [中文文档](http://beego.me/docs/intro/) -* [Русский](http://beego.me/docs/intro/) +* [orm](https://github.com/beego/beedoc/tree/master/en-US/mvc/model) +* [session](https://github.com/beego/beedoc/blob/master/en-US/module/session.md) +* [logs](https://github.com/beego/beedoc/blob/master/en-US/module/logs.md) +* [config](https://github.com/beego/beedoc/blob/master/en-US/module/config.md) +* [cache](https://github.com/beego/beedoc/blob/master/en-US/module/cache.md) +* [context](https://github.com/beego/beedoc/blob/master/en-US/module/context.md) +* [admin](https://github.com/beego/beedoc/blob/master/en-US/module/admin.md) +* [httplib](https://github.com/beego/beedoc/blob/master/en-US/module/httplib.md) +* [task](https://github.com/beego/beedoc/blob/master/en-US/module/task.md) +* [i18n](https://github.com/beego/beedoc/blob/master/en-US/module/i18n.md) ## Community -* [http://beego.me/community](http://beego.me/community) -* Welcome to join us in Slack: [https://beego.slack.com](https://beego.slack.com), you can get invited from [here](https://github.com/beego/beedoc/issues/232) +* Welcome to join us in Slack: [https://beego.slack.com invite](https://join.slack.com/t/beego/shared_invite/zt-fqlfjaxs-_CRmiITCSbEqQG9NeBqXKA), +* QQ Group ID:523992905 +* [Contribution Guide](https://github.com/beego/beedoc/blob/master/en-US/intro/contributing.md). ## License beego source code is licensed under the Apache Licence, Version 2.0 -(http://www.apache.org/licenses/LICENSE-2.0.html). +([https://www.apache.org/licenses/LICENSE-2.0.html](https://www.apache.org/licenses/LICENSE-2.0.html)). diff --git a/admin.go b/admin.go deleted file mode 100644 index 2560650117..0000000000 --- a/admin.go +++ /dev/null @@ -1,415 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package beego - -import ( - "bytes" - "encoding/json" - "fmt" - "net/http" - "os" - "reflect" - "text/template" - "time" - - "github.com/astaxie/beego/grace" - "github.com/astaxie/beego/logs" - "github.com/astaxie/beego/toolbox" - "github.com/astaxie/beego/utils" -) - -// BeeAdminApp is the default adminApp used by admin module. -var beeAdminApp *adminApp - -// FilterMonitorFunc is default monitor filter when admin module is enable. -// if this func returns, admin module records qps for this request by condition of this function logic. -// usage: -// func MyFilterMonitor(method, requestPath string, t time.Duration, pattern string, statusCode int) bool { -// if method == "POST" { -// return false -// } -// if t.Nanoseconds() < 100 { -// return false -// } -// if strings.HasPrefix(requestPath, "/astaxie") { -// return false -// } -// return true -// } -// beego.FilterMonitorFunc = MyFilterMonitor. -var FilterMonitorFunc func(string, string, time.Duration, string, int) bool - -func init() { - beeAdminApp = &adminApp{ - routers: make(map[string]http.HandlerFunc), - } - beeAdminApp.Route("/", adminIndex) - beeAdminApp.Route("/qps", qpsIndex) - beeAdminApp.Route("/prof", profIndex) - beeAdminApp.Route("/healthcheck", healthcheck) - beeAdminApp.Route("/task", taskStatus) - beeAdminApp.Route("/listconf", listConf) - FilterMonitorFunc = func(string, string, time.Duration, string, int) bool { return true } -} - -// AdminIndex is the default http.Handler for admin module. -// it matches url pattern "/". -func adminIndex(rw http.ResponseWriter, _ *http.Request) { - execTpl(rw, map[interface{}]interface{}{}, indexTpl, defaultScriptsTpl) -} - -// QpsIndex is the http.Handler for writing qps statistics map result info in http.ResponseWriter. -// it's registered with url pattern "/qps" in admin module. -func qpsIndex(rw http.ResponseWriter, _ *http.Request) { - data := make(map[interface{}]interface{}) - data["Content"] = toolbox.StatisticsMap.GetMap() - - // do html escape before display path, avoid xss - if content, ok := (data["Content"]).(M); ok { - if resultLists, ok := (content["Data"]).([][]string); ok { - for i := range resultLists { - if len(resultLists[i]) > 0 { - resultLists[i][0] = template.HTMLEscapeString(resultLists[i][0]) - } - } - } - } - - execTpl(rw, data, qpsTpl, defaultScriptsTpl) -} - -// ListConf is the http.Handler of displaying all beego configuration values as key/value pair. -// it's registered with url pattern "/listconf" in admin module. -func listConf(rw http.ResponseWriter, r *http.Request) { - r.ParseForm() - command := r.Form.Get("command") - if command == "" { - rw.Write([]byte("command not support")) - return - } - - data := make(map[interface{}]interface{}) - switch command { - case "conf": - m := make(M) - list("BConfig", BConfig, m) - m["AppConfigPath"] = appConfigPath - m["AppConfigProvider"] = appConfigProvider - tmpl := template.Must(template.New("dashboard").Parse(dashboardTpl)) - tmpl = template.Must(tmpl.Parse(configTpl)) - tmpl = template.Must(tmpl.Parse(defaultScriptsTpl)) - - data["Content"] = m - - tmpl.Execute(rw, data) - - case "router": - content := PrintTree() - content["Fields"] = []string{ - "Router Pattern", - "Methods", - "Controller", - } - data["Content"] = content - data["Title"] = "Routers" - execTpl(rw, data, routerAndFilterTpl, defaultScriptsTpl) - case "filter": - var ( - content = M{ - "Fields": []string{ - "Router Pattern", - "Filter Function", - }, - } - filterTypes = []string{} - filterTypeData = make(M) - ) - - if BeeApp.Handlers.enableFilter { - var filterType string - for k, fr := range map[int]string{ - BeforeStatic: "Before Static", - BeforeRouter: "Before Router", - BeforeExec: "Before Exec", - AfterExec: "After Exec", - FinishRouter: "Finish Router"} { - if bf := BeeApp.Handlers.filters[k]; len(bf) > 0 { - filterType = fr - filterTypes = append(filterTypes, filterType) - resultList := new([][]string) - for _, f := range bf { - var result = []string{ - f.pattern, - utils.GetFuncName(f.filterFunc), - } - *resultList = append(*resultList, result) - } - filterTypeData[filterType] = resultList - } - } - } - - content["Data"] = filterTypeData - content["Methods"] = filterTypes - - data["Content"] = content - data["Title"] = "Filters" - execTpl(rw, data, routerAndFilterTpl, defaultScriptsTpl) - default: - rw.Write([]byte("command not support")) - } -} - -func list(root string, p interface{}, m M) { - pt := reflect.TypeOf(p) - pv := reflect.ValueOf(p) - if pt.Kind() == reflect.Ptr { - pt = pt.Elem() - pv = pv.Elem() - } - for i := 0; i < pv.NumField(); i++ { - var key string - if root == "" { - key = pt.Field(i).Name - } else { - key = root + "." + pt.Field(i).Name - } - if pv.Field(i).Kind() == reflect.Struct { - list(key, pv.Field(i).Interface(), m) - } else { - m[key] = pv.Field(i).Interface() - } - } -} - -// PrintTree prints all registered routers. -func PrintTree() M { - var ( - content = M{} - methods = []string{} - methodsData = make(M) - ) - for method, t := range BeeApp.Handlers.routers { - - resultList := new([][]string) - - printTree(resultList, t) - - methods = append(methods, method) - methodsData[method] = resultList - } - - content["Data"] = methodsData - content["Methods"] = methods - return content -} - -func printTree(resultList *[][]string, t *Tree) { - for _, tr := range t.fixrouters { - printTree(resultList, tr) - } - if t.wildcard != nil { - printTree(resultList, t.wildcard) - } - for _, l := range t.leaves { - if v, ok := l.runObject.(*ControllerInfo); ok { - if v.routerType == routerTypeBeego { - var result = []string{ - v.pattern, - fmt.Sprintf("%s", v.methods), - v.controllerType.String(), - } - *resultList = append(*resultList, result) - } else if v.routerType == routerTypeRESTFul { - var result = []string{ - v.pattern, - fmt.Sprintf("%s", v.methods), - "", - } - *resultList = append(*resultList, result) - } else if v.routerType == routerTypeHandler { - var result = []string{ - v.pattern, - "", - "", - } - *resultList = append(*resultList, result) - } - } - } -} - -// ProfIndex is a http.Handler for showing profile command. -// it's in url pattern "/prof" in admin module. -func profIndex(rw http.ResponseWriter, r *http.Request) { - r.ParseForm() - command := r.Form.Get("command") - if command == "" { - return - } - - var ( - format = r.Form.Get("format") - data = make(map[interface{}]interface{}) - result bytes.Buffer - ) - toolbox.ProcessInput(command, &result) - data["Content"] = result.String() - - if format == "json" && command == "gc summary" { - dataJSON, err := json.Marshal(data) - if err != nil { - http.Error(rw, err.Error(), http.StatusInternalServerError) - return - } - - rw.Header().Set("Content-Type", "application/json") - rw.Write(dataJSON) - return - } - - data["Title"] = command - defaultTpl := defaultScriptsTpl - if command == "gc summary" { - defaultTpl = gcAjaxTpl - } - execTpl(rw, data, profillingTpl, defaultTpl) -} - -// Healthcheck is a http.Handler calling health checking and showing the result. -// it's in "/healthcheck" pattern in admin module. -func healthcheck(rw http.ResponseWriter, _ *http.Request) { - var ( - result []string - data = make(map[interface{}]interface{}) - resultList = new([][]string) - content = M{ - "Fields": []string{"Name", "Message", "Status"}, - } - ) - - for name, h := range toolbox.AdminCheckList { - if err := h.Check(); err != nil { - result = []string{ - "error", - name, - err.Error(), - } - } else { - result = []string{ - "success", - name, - "OK", - } - } - *resultList = append(*resultList, result) - } - - content["Data"] = resultList - data["Content"] = content - data["Title"] = "Health Check" - execTpl(rw, data, healthCheckTpl, defaultScriptsTpl) -} - -// TaskStatus is a http.Handler with running task status (task name, status and the last execution). -// it's in "/task" pattern in admin module. -func taskStatus(rw http.ResponseWriter, req *http.Request) { - data := make(map[interface{}]interface{}) - - // Run Task - req.ParseForm() - taskname := req.Form.Get("taskname") - if taskname != "" { - if t, ok := toolbox.AdminTaskList[taskname]; ok { - if err := t.Run(); err != nil { - data["Message"] = []string{"error", fmt.Sprintf("%s", err)} - } - data["Message"] = []string{"success", fmt.Sprintf("%s run success,Now the Status is
%s", taskname, t.GetStatus())} - } else { - data["Message"] = []string{"warning", fmt.Sprintf("there's no task which named: %s", taskname)} - } - } - - // List Tasks - content := make(M) - resultList := new([][]string) - var fields = []string{ - "Task Name", - "Task Spec", - "Task Status", - "Last Time", - "", - } - for tname, tk := range toolbox.AdminTaskList { - result := []string{ - tname, - tk.GetSpec(), - tk.GetStatus(), - tk.GetPrev().String(), - } - *resultList = append(*resultList, result) - } - - content["Fields"] = fields - content["Data"] = resultList - data["Content"] = content - data["Title"] = "Tasks" - execTpl(rw, data, tasksTpl, defaultScriptsTpl) -} - -func execTpl(rw http.ResponseWriter, data map[interface{}]interface{}, tpls ...string) { - tmpl := template.Must(template.New("dashboard").Parse(dashboardTpl)) - for _, tpl := range tpls { - tmpl = template.Must(tmpl.Parse(tpl)) - } - tmpl.Execute(rw, data) -} - -// adminApp is an http.HandlerFunc map used as beeAdminApp. -type adminApp struct { - routers map[string]http.HandlerFunc -} - -// Route adds http.HandlerFunc to adminApp with url pattern. -func (admin *adminApp) Route(pattern string, f http.HandlerFunc) { - admin.routers[pattern] = f -} - -// Run adminApp http server. -// Its addr is defined in configuration file as adminhttpaddr and adminhttpport. -func (admin *adminApp) Run() { - if len(toolbox.AdminTaskList) > 0 { - toolbox.StartTask() - } - addr := BConfig.Listen.AdminAddr - - if BConfig.Listen.AdminPort != 0 { - addr = fmt.Sprintf("%s:%d", BConfig.Listen.AdminAddr, BConfig.Listen.AdminPort) - } - for p, f := range admin.routers { - http.Handle(p, f) - } - logs.Info("Admin server Running on %s", addr) - - var err error - if BConfig.Listen.Graceful { - err = grace.ListenAndServe(addr, nil) - } else { - err = http.ListenAndServe(addr, nil) - } - if err != nil { - logs.Critical("Admin ListenAndServe: ", err, fmt.Sprintf("%d", os.Getpid())) - } -} diff --git a/admin_test.go b/admin_test.go deleted file mode 100644 index 539837cfe0..0000000000 --- a/admin_test.go +++ /dev/null @@ -1,75 +0,0 @@ -package beego - -import ( - "fmt" - "testing" -) - -func TestList_01(t *testing.T) { - m := make(M) - list("BConfig", BConfig, m) - t.Log(m) - om := oldMap() - for k, v := range om { - if fmt.Sprint(m[k]) != fmt.Sprint(v) { - t.Log(k, "old-key", v, "new-key", m[k]) - t.FailNow() - } - } -} - -func oldMap() M { - m := make(M) - m["BConfig.AppName"] = BConfig.AppName - m["BConfig.RunMode"] = BConfig.RunMode - m["BConfig.RouterCaseSensitive"] = BConfig.RouterCaseSensitive - m["BConfig.ServerName"] = BConfig.ServerName - m["BConfig.RecoverPanic"] = BConfig.RecoverPanic - m["BConfig.CopyRequestBody"] = BConfig.CopyRequestBody - m["BConfig.EnableGzip"] = BConfig.EnableGzip - m["BConfig.MaxMemory"] = BConfig.MaxMemory - m["BConfig.EnableErrorsShow"] = BConfig.EnableErrorsShow - m["BConfig.Listen.Graceful"] = BConfig.Listen.Graceful - m["BConfig.Listen.ServerTimeOut"] = BConfig.Listen.ServerTimeOut - m["BConfig.Listen.ListenTCP4"] = BConfig.Listen.ListenTCP4 - m["BConfig.Listen.EnableHTTP"] = BConfig.Listen.EnableHTTP - m["BConfig.Listen.HTTPAddr"] = BConfig.Listen.HTTPAddr - m["BConfig.Listen.HTTPPort"] = BConfig.Listen.HTTPPort - m["BConfig.Listen.EnableHTTPS"] = BConfig.Listen.EnableHTTPS - m["BConfig.Listen.HTTPSAddr"] = BConfig.Listen.HTTPSAddr - m["BConfig.Listen.HTTPSPort"] = BConfig.Listen.HTTPSPort - m["BConfig.Listen.HTTPSCertFile"] = BConfig.Listen.HTTPSCertFile - m["BConfig.Listen.HTTPSKeyFile"] = BConfig.Listen.HTTPSKeyFile - m["BConfig.Listen.EnableAdmin"] = BConfig.Listen.EnableAdmin - m["BConfig.Listen.AdminAddr"] = BConfig.Listen.AdminAddr - m["BConfig.Listen.AdminPort"] = BConfig.Listen.AdminPort - m["BConfig.Listen.EnableFcgi"] = BConfig.Listen.EnableFcgi - m["BConfig.Listen.EnableStdIo"] = BConfig.Listen.EnableStdIo - m["BConfig.WebConfig.AutoRender"] = BConfig.WebConfig.AutoRender - m["BConfig.WebConfig.EnableDocs"] = BConfig.WebConfig.EnableDocs - m["BConfig.WebConfig.FlashName"] = BConfig.WebConfig.FlashName - m["BConfig.WebConfig.FlashSeparator"] = BConfig.WebConfig.FlashSeparator - m["BConfig.WebConfig.DirectoryIndex"] = BConfig.WebConfig.DirectoryIndex - m["BConfig.WebConfig.StaticDir"] = BConfig.WebConfig.StaticDir - m["BConfig.WebConfig.StaticExtensionsToGzip"] = BConfig.WebConfig.StaticExtensionsToGzip - m["BConfig.WebConfig.TemplateLeft"] = BConfig.WebConfig.TemplateLeft - m["BConfig.WebConfig.TemplateRight"] = BConfig.WebConfig.TemplateRight - m["BConfig.WebConfig.ViewsPath"] = BConfig.WebConfig.ViewsPath - m["BConfig.WebConfig.EnableXSRF"] = BConfig.WebConfig.EnableXSRF - m["BConfig.WebConfig.XSRFExpire"] = BConfig.WebConfig.XSRFExpire - m["BConfig.WebConfig.Session.SessionOn"] = BConfig.WebConfig.Session.SessionOn - m["BConfig.WebConfig.Session.SessionProvider"] = BConfig.WebConfig.Session.SessionProvider - m["BConfig.WebConfig.Session.SessionName"] = BConfig.WebConfig.Session.SessionName - m["BConfig.WebConfig.Session.SessionGCMaxLifetime"] = BConfig.WebConfig.Session.SessionGCMaxLifetime - m["BConfig.WebConfig.Session.SessionProviderConfig"] = BConfig.WebConfig.Session.SessionProviderConfig - m["BConfig.WebConfig.Session.SessionCookieLifeTime"] = BConfig.WebConfig.Session.SessionCookieLifeTime - m["BConfig.WebConfig.Session.SessionAutoSetCookie"] = BConfig.WebConfig.Session.SessionAutoSetCookie - m["BConfig.WebConfig.Session.SessionDomain"] = BConfig.WebConfig.Session.SessionDomain - m["BConfig.WebConfig.Session.SessionDisableHTTPOnly"] = BConfig.WebConfig.Session.SessionDisableHTTPOnly - m["BConfig.Log.AccessLogs"] = BConfig.Log.AccessLogs - m["BConfig.Log.EnableStaticLogs"] = BConfig.Log.EnableStaticLogs - m["BConfig.Log.AccessLogsFormat"] = BConfig.Log.AccessLogsFormat - m["BConfig.Log.FileLineNum"] = BConfig.Log.FileLineNum - m["BConfig.Log.Outputs"] = BConfig.Log.Outputs - return m -} diff --git a/app.go b/app.go deleted file mode 100644 index d9e85e9b63..0000000000 --- a/app.go +++ /dev/null @@ -1,497 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package beego - -import ( - "crypto/tls" - "crypto/x509" - "fmt" - "io/ioutil" - "net" - "net/http" - "net/http/fcgi" - "os" - "path" - "strings" - "time" - - "github.com/astaxie/beego/grace" - "github.com/astaxie/beego/logs" - "github.com/astaxie/beego/utils" - "golang.org/x/crypto/acme/autocert" -) - -var ( - // BeeApp is an application instance - BeeApp *App -) - -func init() { - // create beego application - BeeApp = NewApp() -} - -// App defines beego application with a new PatternServeMux. -type App struct { - Handlers *ControllerRegister - Server *http.Server -} - -// NewApp returns a new beego application. -func NewApp() *App { - cr := NewControllerRegister() - app := &App{Handlers: cr, Server: &http.Server{}} - return app -} - -// MiddleWare function for http.Handler -type MiddleWare func(http.Handler) http.Handler - -// Run beego application. -func (app *App) Run(mws ...MiddleWare) { - addr := BConfig.Listen.HTTPAddr - - if BConfig.Listen.HTTPPort != 0 { - addr = fmt.Sprintf("%s:%d", BConfig.Listen.HTTPAddr, BConfig.Listen.HTTPPort) - } - - var ( - err error - l net.Listener - endRunning = make(chan bool, 1) - ) - - // run cgi server - if BConfig.Listen.EnableFcgi { - if BConfig.Listen.EnableStdIo { - if err = fcgi.Serve(nil, app.Handlers); err == nil { // standard I/O - logs.Info("Use FCGI via standard I/O") - } else { - logs.Critical("Cannot use FCGI via standard I/O", err) - } - return - } - if BConfig.Listen.HTTPPort == 0 { - // remove the Socket file before start - if utils.FileExists(addr) { - os.Remove(addr) - } - l, err = net.Listen("unix", addr) - } else { - l, err = net.Listen("tcp", addr) - } - if err != nil { - logs.Critical("Listen: ", err) - } - if err = fcgi.Serve(l, app.Handlers); err != nil { - logs.Critical("fcgi.Serve: ", err) - } - return - } - - app.Server.Handler = app.Handlers - for i := len(mws) - 1; i >= 0; i-- { - if mws[i] == nil { - continue - } - app.Server.Handler = mws[i](app.Server.Handler) - } - app.Server.ReadTimeout = time.Duration(BConfig.Listen.ServerTimeOut) * time.Second - app.Server.WriteTimeout = time.Duration(BConfig.Listen.ServerTimeOut) * time.Second - app.Server.ErrorLog = logs.GetLogger("HTTP") - - // run graceful mode - if BConfig.Listen.Graceful { - httpsAddr := BConfig.Listen.HTTPSAddr - app.Server.Addr = httpsAddr - if BConfig.Listen.EnableHTTPS || BConfig.Listen.EnableMutualHTTPS { - go func() { - time.Sleep(1000 * time.Microsecond) - if BConfig.Listen.HTTPSPort != 0 { - httpsAddr = fmt.Sprintf("%s:%d", BConfig.Listen.HTTPSAddr, BConfig.Listen.HTTPSPort) - app.Server.Addr = httpsAddr - } - server := grace.NewServer(httpsAddr, app.Handlers) - server.Server.ReadTimeout = app.Server.ReadTimeout - server.Server.WriteTimeout = app.Server.WriteTimeout - if BConfig.Listen.EnableMutualHTTPS { - if err := server.ListenAndServeMutualTLS(BConfig.Listen.HTTPSCertFile, BConfig.Listen.HTTPSKeyFile, BConfig.Listen.TrustCaFile); err != nil { - logs.Critical("ListenAndServeTLS: ", err, fmt.Sprintf("%d", os.Getpid())) - time.Sleep(100 * time.Microsecond) - endRunning <- true - } - } else { - if BConfig.Listen.AutoTLS { - m := autocert.Manager{ - Prompt: autocert.AcceptTOS, - HostPolicy: autocert.HostWhitelist(BConfig.Listen.Domains...), - Cache: autocert.DirCache(BConfig.Listen.TLSCacheDir), - } - app.Server.TLSConfig = &tls.Config{GetCertificate: m.GetCertificate} - BConfig.Listen.HTTPSCertFile, BConfig.Listen.HTTPSKeyFile = "", "" - } - if err := server.ListenAndServeTLS(BConfig.Listen.HTTPSCertFile, BConfig.Listen.HTTPSKeyFile); err != nil { - logs.Critical("ListenAndServeTLS: ", err, fmt.Sprintf("%d", os.Getpid())) - time.Sleep(100 * time.Microsecond) - endRunning <- true - } - } - }() - } - if BConfig.Listen.EnableHTTP { - go func() { - server := grace.NewServer(addr, app.Handlers) - server.Server.ReadTimeout = app.Server.ReadTimeout - server.Server.WriteTimeout = app.Server.WriteTimeout - if BConfig.Listen.ListenTCP4 { - server.Network = "tcp4" - } - if err := server.ListenAndServe(); err != nil { - logs.Critical("ListenAndServe: ", err, fmt.Sprintf("%d", os.Getpid())) - time.Sleep(100 * time.Microsecond) - endRunning <- true - } - }() - } - <-endRunning - return - } - - // run normal mode - if BConfig.Listen.EnableHTTPS || BConfig.Listen.EnableMutualHTTPS { - go func() { - time.Sleep(1000 * time.Microsecond) - if BConfig.Listen.HTTPSPort != 0 { - app.Server.Addr = fmt.Sprintf("%s:%d", BConfig.Listen.HTTPSAddr, BConfig.Listen.HTTPSPort) - } else if BConfig.Listen.EnableHTTP { - logs.Info("Start https server error, conflict with http. Please reset https port") - return - } - logs.Info("https server Running on https://%s", app.Server.Addr) - if BConfig.Listen.AutoTLS { - m := autocert.Manager{ - Prompt: autocert.AcceptTOS, - HostPolicy: autocert.HostWhitelist(BConfig.Listen.Domains...), - Cache: autocert.DirCache(BConfig.Listen.TLSCacheDir), - } - app.Server.TLSConfig = &tls.Config{GetCertificate: m.GetCertificate} - BConfig.Listen.HTTPSCertFile, BConfig.Listen.HTTPSKeyFile = "", "" - } else if BConfig.Listen.EnableMutualHTTPS { - pool := x509.NewCertPool() - data, err := ioutil.ReadFile(BConfig.Listen.TrustCaFile) - if err != nil { - logs.Info("MutualHTTPS should provide TrustCaFile") - return - } - pool.AppendCertsFromPEM(data) - app.Server.TLSConfig = &tls.Config{ - ClientCAs: pool, - ClientAuth: tls.RequireAndVerifyClientCert, - } - } - if err := app.Server.ListenAndServeTLS(BConfig.Listen.HTTPSCertFile, BConfig.Listen.HTTPSKeyFile); err != nil { - logs.Critical("ListenAndServeTLS: ", err) - time.Sleep(100 * time.Microsecond) - endRunning <- true - } - }() - - } - if BConfig.Listen.EnableHTTP { - go func() { - app.Server.Addr = addr - logs.Info("http server Running on http://%s", app.Server.Addr) - if BConfig.Listen.ListenTCP4 { - ln, err := net.Listen("tcp4", app.Server.Addr) - if err != nil { - logs.Critical("ListenAndServe: ", err) - time.Sleep(100 * time.Microsecond) - endRunning <- true - return - } - if err = app.Server.Serve(ln); err != nil { - logs.Critical("ListenAndServe: ", err) - time.Sleep(100 * time.Microsecond) - endRunning <- true - return - } - } else { - if err := app.Server.ListenAndServe(); err != nil { - logs.Critical("ListenAndServe: ", err) - time.Sleep(100 * time.Microsecond) - endRunning <- true - } - } - }() - } - <-endRunning -} - -// Router adds a patterned controller handler to BeeApp. -// it's an alias method of App.Router. -// usage: -// simple router -// beego.Router("/admin", &admin.UserController{}) -// beego.Router("/admin/index", &admin.ArticleController{}) -// -// regex router -// -// beego.Router("/api/:id([0-9]+)", &controllers.RController{}) -// -// custom rules -// beego.Router("/api/list",&RestController{},"*:ListFood") -// beego.Router("/api/create",&RestController{},"post:CreateFood") -// beego.Router("/api/update",&RestController{},"put:UpdateFood") -// beego.Router("/api/delete",&RestController{},"delete:DeleteFood") -func Router(rootpath string, c ControllerInterface, mappingMethods ...string) *App { - BeeApp.Handlers.Add(rootpath, c, mappingMethods...) - return BeeApp -} - -// UnregisterFixedRoute unregisters the route with the specified fixedRoute. It is particularly useful -// in web applications that inherit most routes from a base webapp via the underscore -// import, and aim to overwrite only certain paths. -// The method parameter can be empty or "*" for all HTTP methods, or a particular -// method type (e.g. "GET" or "POST") for selective removal. -// -// Usage (replace "GET" with "*" for all methods): -// beego.UnregisterFixedRoute("/yourpreviouspath", "GET") -// beego.Router("/yourpreviouspath", yourControllerAddress, "get:GetNewPage") -func UnregisterFixedRoute(fixedRoute string, method string) *App { - subPaths := splitPath(fixedRoute) - if method == "" || method == "*" { - for m := range HTTPMETHOD { - if _, ok := BeeApp.Handlers.routers[m]; !ok { - continue - } - if BeeApp.Handlers.routers[m].prefix == strings.Trim(fixedRoute, "/ ") { - findAndRemoveSingleTree(BeeApp.Handlers.routers[m]) - continue - } - findAndRemoveTree(subPaths, BeeApp.Handlers.routers[m], m) - } - return BeeApp - } - // Single HTTP method - um := strings.ToUpper(method) - if _, ok := BeeApp.Handlers.routers[um]; ok { - if BeeApp.Handlers.routers[um].prefix == strings.Trim(fixedRoute, "/ ") { - findAndRemoveSingleTree(BeeApp.Handlers.routers[um]) - return BeeApp - } - findAndRemoveTree(subPaths, BeeApp.Handlers.routers[um], um) - } - return BeeApp -} - -func findAndRemoveTree(paths []string, entryPointTree *Tree, method string) { - for i := range entryPointTree.fixrouters { - if entryPointTree.fixrouters[i].prefix == paths[0] { - if len(paths) == 1 { - if len(entryPointTree.fixrouters[i].fixrouters) > 0 { - // If the route had children subtrees, remove just the functional leaf, - // to allow children to function as before - if len(entryPointTree.fixrouters[i].leaves) > 0 { - entryPointTree.fixrouters[i].leaves[0] = nil - entryPointTree.fixrouters[i].leaves = entryPointTree.fixrouters[i].leaves[1:] - } - } else { - // Remove the *Tree from the fixrouters slice - entryPointTree.fixrouters[i] = nil - - if i == len(entryPointTree.fixrouters)-1 { - entryPointTree.fixrouters = entryPointTree.fixrouters[:i] - } else { - entryPointTree.fixrouters = append(entryPointTree.fixrouters[:i], entryPointTree.fixrouters[i+1:len(entryPointTree.fixrouters)]...) - } - } - return - } - findAndRemoveTree(paths[1:], entryPointTree.fixrouters[i], method) - } - } -} - -func findAndRemoveSingleTree(entryPointTree *Tree) { - if entryPointTree == nil { - return - } - if len(entryPointTree.fixrouters) > 0 { - // If the route had children subtrees, remove just the functional leaf, - // to allow children to function as before - if len(entryPointTree.leaves) > 0 { - entryPointTree.leaves[0] = nil - entryPointTree.leaves = entryPointTree.leaves[1:] - } - } -} - -// Include will generate router file in the router/xxx.go from the controller's comments -// usage: -// beego.Include(&BankAccount{}, &OrderController{},&RefundController{},&ReceiptController{}) -// type BankAccount struct{ -// beego.Controller -// } -// -// register the function -// func (b *BankAccount)Mapping(){ -// b.Mapping("ShowAccount" , b.ShowAccount) -// b.Mapping("ModifyAccount", b.ModifyAccount) -//} -// -// //@router /account/:id [get] -// func (b *BankAccount) ShowAccount(){ -// //logic -// } -// -// -// //@router /account/:id [post] -// func (b *BankAccount) ModifyAccount(){ -// //logic -// } -// -// the comments @router url methodlist -// url support all the function Router's pattern -// methodlist [get post head put delete options *] -func Include(cList ...ControllerInterface) *App { - BeeApp.Handlers.Include(cList...) - return BeeApp -} - -// RESTRouter adds a restful controller handler to BeeApp. -// its' controller implements beego.ControllerInterface and -// defines a param "pattern/:objectId" to visit each resource. -func RESTRouter(rootpath string, c ControllerInterface) *App { - Router(rootpath, c) - Router(path.Join(rootpath, ":objectId"), c) - return BeeApp -} - -// AutoRouter adds defined controller handler to BeeApp. -// it's same to App.AutoRouter. -// if beego.AddAuto(&MainContorlller{}) and MainController has methods List and Page, -// visit the url /main/list to exec List function or /main/page to exec Page function. -func AutoRouter(c ControllerInterface) *App { - BeeApp.Handlers.AddAuto(c) - return BeeApp -} - -// AutoPrefix adds controller handler to BeeApp with prefix. -// it's same to App.AutoRouterWithPrefix. -// if beego.AutoPrefix("/admin",&MainContorlller{}) and MainController has methods List and Page, -// visit the url /admin/main/list to exec List function or /admin/main/page to exec Page function. -func AutoPrefix(prefix string, c ControllerInterface) *App { - BeeApp.Handlers.AddAutoPrefix(prefix, c) - return BeeApp -} - -// Get used to register router for Get method -// usage: -// beego.Get("/", func(ctx *context.Context){ -// ctx.Output.Body("hello world") -// }) -func Get(rootpath string, f FilterFunc) *App { - BeeApp.Handlers.Get(rootpath, f) - return BeeApp -} - -// Post used to register router for Post method -// usage: -// beego.Post("/api", func(ctx *context.Context){ -// ctx.Output.Body("hello world") -// }) -func Post(rootpath string, f FilterFunc) *App { - BeeApp.Handlers.Post(rootpath, f) - return BeeApp -} - -// Delete used to register router for Delete method -// usage: -// beego.Delete("/api", func(ctx *context.Context){ -// ctx.Output.Body("hello world") -// }) -func Delete(rootpath string, f FilterFunc) *App { - BeeApp.Handlers.Delete(rootpath, f) - return BeeApp -} - -// Put used to register router for Put method -// usage: -// beego.Put("/api", func(ctx *context.Context){ -// ctx.Output.Body("hello world") -// }) -func Put(rootpath string, f FilterFunc) *App { - BeeApp.Handlers.Put(rootpath, f) - return BeeApp -} - -// Head used to register router for Head method -// usage: -// beego.Head("/api", func(ctx *context.Context){ -// ctx.Output.Body("hello world") -// }) -func Head(rootpath string, f FilterFunc) *App { - BeeApp.Handlers.Head(rootpath, f) - return BeeApp -} - -// Options used to register router for Options method -// usage: -// beego.Options("/api", func(ctx *context.Context){ -// ctx.Output.Body("hello world") -// }) -func Options(rootpath string, f FilterFunc) *App { - BeeApp.Handlers.Options(rootpath, f) - return BeeApp -} - -// Patch used to register router for Patch method -// usage: -// beego.Patch("/api", func(ctx *context.Context){ -// ctx.Output.Body("hello world") -// }) -func Patch(rootpath string, f FilterFunc) *App { - BeeApp.Handlers.Patch(rootpath, f) - return BeeApp -} - -// Any used to register router for all methods -// usage: -// beego.Any("/api", func(ctx *context.Context){ -// ctx.Output.Body("hello world") -// }) -func Any(rootpath string, f FilterFunc) *App { - BeeApp.Handlers.Any(rootpath, f) - return BeeApp -} - -// Handler used to register a Handler router -// usage: -// beego.Handler("/api", http.HandlerFunc(func (w http.ResponseWriter, r *http.Request) { -// fmt.Fprintf(w, "Hello, %q", html.EscapeString(r.URL.Path)) -// })) -func Handler(rootpath string, h http.Handler, options ...interface{}) *App { - BeeApp.Handlers.Handler(rootpath, h, options...) - return BeeApp -} - -// InsertFilter adds a FilterFunc with pattern condition and action constant. -// The pos means action constant including -// beego.BeforeStatic, beego.BeforeRouter, beego.BeforeExec, beego.AfterExec and beego.FinishRouter. -// The bool params is for setting the returnOnOutput value (false allows multiple filters to execute) -func InsertFilter(pattern string, pos int, filter FilterFunc, params ...bool) *App { - BeeApp.Handlers.InsertFilter(pattern, pos, filter, params...) - return BeeApp -} diff --git a/build_info.go b/build_info.go new file mode 100644 index 0000000000..db898a23a2 --- /dev/null +++ b/build_info.go @@ -0,0 +1,32 @@ +// Copyright 2020 astaxie +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package beego + +var ( + BuildVersion string + BuildGitRevision string + BuildStatus string + BuildTag string + BuildTime string + + GoVersion string + + GitBranch string +) + +const ( + // VERSION represent beego web framework version. + VERSION = "2.1.2" +) diff --git a/cache/README.md b/cache/README.md deleted file mode 100644 index b467760afe..0000000000 --- a/cache/README.md +++ /dev/null @@ -1,59 +0,0 @@ -## cache -cache is a Go cache manager. It can use many cache adapters. The repo is inspired by `database/sql` . - - -## How to install? - - go get github.com/astaxie/beego/cache - - -## What adapters are supported? - -As of now this cache support memory, Memcache and Redis. - - -## How to use it? - -First you must import it - - import ( - "github.com/astaxie/beego/cache" - ) - -Then init a Cache (example with memory adapter) - - bm, err := cache.NewCache("memory", `{"interval":60}`) - -Use it like this: - - bm.Put("astaxie", 1, 10 * time.Second) - bm.Get("astaxie") - bm.IsExist("astaxie") - bm.Delete("astaxie") - - -## Memory adapter - -Configure memory adapter like this: - - {"interval":60} - -interval means the gc time. The cache will check at each time interval, whether item has expired. - - -## Memcache adapter - -Memcache adapter use the [gomemcache](http://github.com/bradfitz/gomemcache) client. - -Configure like this: - - {"conn":"127.0.0.1:11211"} - - -## Redis adapter - -Redis adapter use the [redigo](http://github.com/gomodule/redigo) client. - -Configure like this: - - {"conn":":6039"} diff --git a/cache/cache_test.go b/cache/cache_test.go deleted file mode 100644 index 1e71d075a6..0000000000 --- a/cache/cache_test.go +++ /dev/null @@ -1,168 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package cache - -import ( - "os" - "testing" - "time" -) - -func TestCache(t *testing.T) { - bm, err := NewCache("memory", `{"interval":20}`) - if err != nil { - t.Error("init err") - } - timeoutDuration := 10 * time.Second - if err = bm.Put("astaxie", 1, timeoutDuration); err != nil { - t.Error("set Error", err) - } - if !bm.IsExist("astaxie") { - t.Error("check err") - } - - if v := bm.Get("astaxie"); v.(int) != 1 { - t.Error("get err") - } - - time.Sleep(30 * time.Second) - - if bm.IsExist("astaxie") { - t.Error("check err") - } - - if err = bm.Put("astaxie", 1, timeoutDuration); err != nil { - t.Error("set Error", err) - } - - if err = bm.Incr("astaxie"); err != nil { - t.Error("Incr Error", err) - } - - if v := bm.Get("astaxie"); v.(int) != 2 { - t.Error("get err") - } - - if err = bm.Decr("astaxie"); err != nil { - t.Error("Decr Error", err) - } - - if v := bm.Get("astaxie"); v.(int) != 1 { - t.Error("get err") - } - bm.Delete("astaxie") - if bm.IsExist("astaxie") { - t.Error("delete err") - } - - //test GetMulti - if err = bm.Put("astaxie", "author", timeoutDuration); err != nil { - t.Error("set Error", err) - } - if !bm.IsExist("astaxie") { - t.Error("check err") - } - if v := bm.Get("astaxie"); v.(string) != "author" { - t.Error("get err") - } - - if err = bm.Put("astaxie1", "author1", timeoutDuration); err != nil { - t.Error("set Error", err) - } - if !bm.IsExist("astaxie1") { - t.Error("check err") - } - - vv := bm.GetMulti([]string{"astaxie", "astaxie1"}) - if len(vv) != 2 { - t.Error("GetMulti ERROR") - } - if vv[0].(string) != "author" { - t.Error("GetMulti ERROR") - } - if vv[1].(string) != "author1" { - t.Error("GetMulti ERROR") - } -} - -func TestFileCache(t *testing.T) { - bm, err := NewCache("file", `{"CachePath":"cache","FileSuffix":".bin","DirectoryLevel":"2","EmbedExpiry":"0"}`) - if err != nil { - t.Error("init err") - } - timeoutDuration := 10 * time.Second - if err = bm.Put("astaxie", 1, timeoutDuration); err != nil { - t.Error("set Error", err) - } - if !bm.IsExist("astaxie") { - t.Error("check err") - } - - if v := bm.Get("astaxie"); v.(int) != 1 { - t.Error("get err") - } - - if err = bm.Incr("astaxie"); err != nil { - t.Error("Incr Error", err) - } - - if v := bm.Get("astaxie"); v.(int) != 2 { - t.Error("get err") - } - - if err = bm.Decr("astaxie"); err != nil { - t.Error("Decr Error", err) - } - - if v := bm.Get("astaxie"); v.(int) != 1 { - t.Error("get err") - } - bm.Delete("astaxie") - if bm.IsExist("astaxie") { - t.Error("delete err") - } - - //test string - if err = bm.Put("astaxie", "author", timeoutDuration); err != nil { - t.Error("set Error", err) - } - if !bm.IsExist("astaxie") { - t.Error("check err") - } - if v := bm.Get("astaxie"); v.(string) != "author" { - t.Error("get err") - } - - //test GetMulti - if err = bm.Put("astaxie1", "author1", timeoutDuration); err != nil { - t.Error("set Error", err) - } - if !bm.IsExist("astaxie1") { - t.Error("check err") - } - - vv := bm.GetMulti([]string{"astaxie", "astaxie1"}) - if len(vv) != 2 { - t.Error("GetMulti ERROR") - } - if vv[0].(string) != "author" { - t.Error("GetMulti ERROR") - } - if vv[1].(string) != "author1" { - t.Error("GetMulti ERROR") - } - - os.RemoveAll("cache") -} diff --git a/cache/conv_test.go b/cache/conv_test.go deleted file mode 100644 index b90e224a36..0000000000 --- a/cache/conv_test.go +++ /dev/null @@ -1,143 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package cache - -import ( - "testing" -) - -func TestGetString(t *testing.T) { - var t1 = "test1" - if "test1" != GetString(t1) { - t.Error("get string from string error") - } - var t2 = []byte("test2") - if "test2" != GetString(t2) { - t.Error("get string from byte array error") - } - var t3 = 1 - if "1" != GetString(t3) { - t.Error("get string from int error") - } - var t4 int64 = 1 - if "1" != GetString(t4) { - t.Error("get string from int64 error") - } - var t5 = 1.1 - if "1.1" != GetString(t5) { - t.Error("get string from float64 error") - } - - if "" != GetString(nil) { - t.Error("get string from nil error") - } -} - -func TestGetInt(t *testing.T) { - var t1 = 1 - if 1 != GetInt(t1) { - t.Error("get int from int error") - } - var t2 int32 = 32 - if 32 != GetInt(t2) { - t.Error("get int from int32 error") - } - var t3 int64 = 64 - if 64 != GetInt(t3) { - t.Error("get int from int64 error") - } - var t4 = "128" - if 128 != GetInt(t4) { - t.Error("get int from num string error") - } - if 0 != GetInt(nil) { - t.Error("get int from nil error") - } -} - -func TestGetInt64(t *testing.T) { - var i int64 = 1 - var t1 = 1 - if i != GetInt64(t1) { - t.Error("get int64 from int error") - } - var t2 int32 = 1 - if i != GetInt64(t2) { - t.Error("get int64 from int32 error") - } - var t3 int64 = 1 - if i != GetInt64(t3) { - t.Error("get int64 from int64 error") - } - var t4 = "1" - if i != GetInt64(t4) { - t.Error("get int64 from num string error") - } - if 0 != GetInt64(nil) { - t.Error("get int64 from nil") - } -} - -func TestGetFloat64(t *testing.T) { - var f = 1.11 - var t1 float32 = 1.11 - if f != GetFloat64(t1) { - t.Error("get float64 from float32 error") - } - var t2 = 1.11 - if f != GetFloat64(t2) { - t.Error("get float64 from float64 error") - } - var t3 = "1.11" - if f != GetFloat64(t3) { - t.Error("get float64 from string error") - } - - var f2 float64 = 1 - var t4 = 1 - if f2 != GetFloat64(t4) { - t.Error("get float64 from int error") - } - - if 0 != GetFloat64(nil) { - t.Error("get float64 from nil error") - } -} - -func TestGetBool(t *testing.T) { - var t1 = true - if !GetBool(t1) { - t.Error("get bool from bool error") - } - var t2 = "true" - if !GetBool(t2) { - t.Error("get bool from string error") - } - if GetBool(nil) { - t.Error("get bool from nil error") - } -} - -func byteArrayEquals(a []byte, b []byte) bool { - if len(a) != len(b) { - return false - } - for i, v := range a { - if v != b[i] { - return false - } - } - return true -} diff --git a/cache/file.go b/cache/file.go deleted file mode 100644 index 6f12d3eee9..0000000000 --- a/cache/file.go +++ /dev/null @@ -1,258 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package cache - -import ( - "bytes" - "crypto/md5" - "encoding/gob" - "encoding/hex" - "encoding/json" - "fmt" - "io" - "io/ioutil" - "os" - "path/filepath" - "reflect" - "strconv" - "time" -) - -// FileCacheItem is basic unit of file cache adapter. -// it contains data and expire time. -type FileCacheItem struct { - Data interface{} - Lastaccess time.Time - Expired time.Time -} - -// FileCache Config -var ( - FileCachePath = "cache" // cache directory - FileCacheFileSuffix = ".bin" // cache file suffix - FileCacheDirectoryLevel = 2 // cache file deep level if auto generated cache files. - FileCacheEmbedExpiry time.Duration // cache expire time, default is no expire forever. -) - -// FileCache is cache adapter for file storage. -type FileCache struct { - CachePath string - FileSuffix string - DirectoryLevel int - EmbedExpiry int -} - -// NewFileCache Create new file cache with no config. -// the level and expiry need set in method StartAndGC as config string. -func NewFileCache() Cache { - // return &FileCache{CachePath:FileCachePath, FileSuffix:FileCacheFileSuffix} - return &FileCache{} -} - -// StartAndGC will start and begin gc for file cache. -// the config need to be like {CachePath:"/cache","FileSuffix":".bin","DirectoryLevel":"2","EmbedExpiry":"0"} -func (fc *FileCache) StartAndGC(config string) error { - - cfg := make(map[string]string) - err := json.Unmarshal([]byte(config), &cfg) - if err != nil { - return err - } - if _, ok := cfg["CachePath"]; !ok { - cfg["CachePath"] = FileCachePath - } - if _, ok := cfg["FileSuffix"]; !ok { - cfg["FileSuffix"] = FileCacheFileSuffix - } - if _, ok := cfg["DirectoryLevel"]; !ok { - cfg["DirectoryLevel"] = strconv.Itoa(FileCacheDirectoryLevel) - } - if _, ok := cfg["EmbedExpiry"]; !ok { - cfg["EmbedExpiry"] = strconv.FormatInt(int64(FileCacheEmbedExpiry.Seconds()), 10) - } - fc.CachePath = cfg["CachePath"] - fc.FileSuffix = cfg["FileSuffix"] - fc.DirectoryLevel, _ = strconv.Atoi(cfg["DirectoryLevel"]) - fc.EmbedExpiry, _ = strconv.Atoi(cfg["EmbedExpiry"]) - - fc.Init() - return nil -} - -// Init will make new dir for file cache if not exist. -func (fc *FileCache) Init() { - if ok, _ := exists(fc.CachePath); !ok { // todo : error handle - _ = os.MkdirAll(fc.CachePath, os.ModePerm) // todo : error handle - } -} - -// get cached file name. it's md5 encoded. -func (fc *FileCache) getCacheFileName(key string) string { - m := md5.New() - io.WriteString(m, key) - keyMd5 := hex.EncodeToString(m.Sum(nil)) - cachePath := fc.CachePath - switch fc.DirectoryLevel { - case 2: - cachePath = filepath.Join(cachePath, keyMd5[0:2], keyMd5[2:4]) - case 1: - cachePath = filepath.Join(cachePath, keyMd5[0:2]) - } - - if ok, _ := exists(cachePath); !ok { // todo : error handle - _ = os.MkdirAll(cachePath, os.ModePerm) // todo : error handle - } - - return filepath.Join(cachePath, fmt.Sprintf("%s%s", keyMd5, fc.FileSuffix)) -} - -// Get value from file cache. -// if non-exist or expired, return empty string. -func (fc *FileCache) Get(key string) interface{} { - fileData, err := FileGetContents(fc.getCacheFileName(key)) - if err != nil { - return "" - } - var to FileCacheItem - GobDecode(fileData, &to) - if to.Expired.Before(time.Now()) { - return "" - } - return to.Data -} - -// GetMulti gets values from file cache. -// if non-exist or expired, return empty string. -func (fc *FileCache) GetMulti(keys []string) []interface{} { - var rc []interface{} - for _, key := range keys { - rc = append(rc, fc.Get(key)) - } - return rc -} - -// Put value into file cache. -// timeout means how long to keep this file, unit of ms. -// if timeout equals fc.EmbedExpiry(default is 0), cache this item forever. -func (fc *FileCache) Put(key string, val interface{}, timeout time.Duration) error { - gob.Register(val) - - item := FileCacheItem{Data: val} - if timeout == time.Duration(fc.EmbedExpiry) { - item.Expired = time.Now().Add((86400 * 365 * 10) * time.Second) // ten years - } else { - item.Expired = time.Now().Add(timeout) - } - item.Lastaccess = time.Now() - data, err := GobEncode(item) - if err != nil { - return err - } - return FilePutContents(fc.getCacheFileName(key), data) -} - -// Delete file cache value. -func (fc *FileCache) Delete(key string) error { - filename := fc.getCacheFileName(key) - if ok, _ := exists(filename); ok { - return os.Remove(filename) - } - return nil -} - -// Incr will increase cached int value. -// fc value is saving forever unless Delete. -func (fc *FileCache) Incr(key string) error { - data := fc.Get(key) - var incr int - if reflect.TypeOf(data).Name() != "int" { - incr = 0 - } else { - incr = data.(int) + 1 - } - fc.Put(key, incr, time.Duration(fc.EmbedExpiry)) - return nil -} - -// Decr will decrease cached int value. -func (fc *FileCache) Decr(key string) error { - data := fc.Get(key) - var decr int - if reflect.TypeOf(data).Name() != "int" || data.(int)-1 <= 0 { - decr = 0 - } else { - decr = data.(int) - 1 - } - fc.Put(key, decr, time.Duration(fc.EmbedExpiry)) - return nil -} - -// IsExist check value is exist. -func (fc *FileCache) IsExist(key string) bool { - ret, _ := exists(fc.getCacheFileName(key)) - return ret -} - -// ClearAll will clean cached files. -// not implemented. -func (fc *FileCache) ClearAll() error { - return nil -} - -// check file exist. -func exists(path string) (bool, error) { - _, err := os.Stat(path) - if err == nil { - return true, nil - } - if os.IsNotExist(err) { - return false, nil - } - return false, err -} - -// FileGetContents Get bytes to file. -// if non-exist, create this file. -func FileGetContents(filename string) (data []byte, e error) { - return ioutil.ReadFile(filename) -} - -// FilePutContents Put bytes to file. -// if non-exist, create this file. -func FilePutContents(filename string, content []byte) error { - return ioutil.WriteFile(filename, content, os.ModePerm) -} - -// GobEncode Gob encodes file cache item. -func GobEncode(data interface{}) ([]byte, error) { - buf := bytes.NewBuffer(nil) - enc := gob.NewEncoder(buf) - err := enc.Encode(data) - if err != nil { - return nil, err - } - return buf.Bytes(), err -} - -// GobDecode Gob decodes file cache item. -func GobDecode(data []byte, to *FileCacheItem) error { - buf := bytes.NewBuffer(data) - dec := gob.NewDecoder(buf) - return dec.Decode(&to) -} - -func init() { - Register("file", NewFileCache) -} diff --git a/cache/memcache/memcache.go b/cache/memcache/memcache.go deleted file mode 100644 index 19116bfac3..0000000000 --- a/cache/memcache/memcache.go +++ /dev/null @@ -1,188 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package memcache for cache provider -// -// depend on github.com/bradfitz/gomemcache/memcache -// -// go install github.com/bradfitz/gomemcache/memcache -// -// Usage: -// import( -// _ "github.com/astaxie/beego/cache/memcache" -// "github.com/astaxie/beego/cache" -// ) -// -// bm, err := cache.NewCache("memcache", `{"conn":"127.0.0.1:11211"}`) -// -// more docs http://beego.me/docs/module/cache.md -package memcache - -import ( - "encoding/json" - "errors" - "strings" - "time" - - "github.com/astaxie/beego/cache" - "github.com/bradfitz/gomemcache/memcache" -) - -// Cache Memcache adapter. -type Cache struct { - conn *memcache.Client - conninfo []string -} - -// NewMemCache create new memcache adapter. -func NewMemCache() cache.Cache { - return &Cache{} -} - -// Get get value from memcache. -func (rc *Cache) Get(key string) interface{} { - if rc.conn == nil { - if err := rc.connectInit(); err != nil { - return err - } - } - if item, err := rc.conn.Get(key); err == nil { - return item.Value - } - return nil -} - -// GetMulti get value from memcache. -func (rc *Cache) GetMulti(keys []string) []interface{} { - size := len(keys) - var rv []interface{} - if rc.conn == nil { - if err := rc.connectInit(); err != nil { - for i := 0; i < size; i++ { - rv = append(rv, err) - } - return rv - } - } - mv, err := rc.conn.GetMulti(keys) - if err == nil { - for _, v := range mv { - rv = append(rv, v.Value) - } - return rv - } - for i := 0; i < size; i++ { - rv = append(rv, err) - } - return rv -} - -// Put put value to memcache. -func (rc *Cache) Put(key string, val interface{}, timeout time.Duration) error { - if rc.conn == nil { - if err := rc.connectInit(); err != nil { - return err - } - } - item := memcache.Item{Key: key, Expiration: int32(timeout / time.Second)} - if v, ok := val.([]byte); ok { - item.Value = v - } else if str, ok := val.(string); ok { - item.Value = []byte(str) - } else { - return errors.New("val only support string and []byte") - } - return rc.conn.Set(&item) -} - -// Delete delete value in memcache. -func (rc *Cache) Delete(key string) error { - if rc.conn == nil { - if err := rc.connectInit(); err != nil { - return err - } - } - return rc.conn.Delete(key) -} - -// Incr increase counter. -func (rc *Cache) Incr(key string) error { - if rc.conn == nil { - if err := rc.connectInit(); err != nil { - return err - } - } - _, err := rc.conn.Increment(key, 1) - return err -} - -// Decr decrease counter. -func (rc *Cache) Decr(key string) error { - if rc.conn == nil { - if err := rc.connectInit(); err != nil { - return err - } - } - _, err := rc.conn.Decrement(key, 1) - return err -} - -// IsExist check value exists in memcache. -func (rc *Cache) IsExist(key string) bool { - if rc.conn == nil { - if err := rc.connectInit(); err != nil { - return false - } - } - _, err := rc.conn.Get(key) - return err == nil -} - -// ClearAll clear all cached in memcache. -func (rc *Cache) ClearAll() error { - if rc.conn == nil { - if err := rc.connectInit(); err != nil { - return err - } - } - return rc.conn.FlushAll() -} - -// StartAndGC start memcache adapter. -// config string is like {"conn":"connection info"}. -// if connecting error, return. -func (rc *Cache) StartAndGC(config string) error { - var cf map[string]string - json.Unmarshal([]byte(config), &cf) - if _, ok := cf["conn"]; !ok { - return errors.New("config has no conn key") - } - rc.conninfo = strings.Split(cf["conn"], ";") - if rc.conn == nil { - if err := rc.connectInit(); err != nil { - return err - } - } - return nil -} - -// connect to memcache and keep the connection. -func (rc *Cache) connectInit() error { - rc.conn = memcache.New(rc.conninfo...) - return nil -} - -func init() { - cache.Register("memcache", NewMemCache) -} diff --git a/cache/memcache/memcache_test.go b/cache/memcache/memcache_test.go deleted file mode 100644 index d9129b6958..0000000000 --- a/cache/memcache/memcache_test.go +++ /dev/null @@ -1,108 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package memcache - -import ( - _ "github.com/bradfitz/gomemcache/memcache" - - "strconv" - "testing" - "time" - - "github.com/astaxie/beego/cache" -) - -func TestMemcacheCache(t *testing.T) { - bm, err := cache.NewCache("memcache", `{"conn": "127.0.0.1:11211"}`) - if err != nil { - t.Error("init err") - } - timeoutDuration := 10 * time.Second - if err = bm.Put("astaxie", "1", timeoutDuration); err != nil { - t.Error("set Error", err) - } - if !bm.IsExist("astaxie") { - t.Error("check err") - } - - time.Sleep(11 * time.Second) - - if bm.IsExist("astaxie") { - t.Error("check err") - } - if err = bm.Put("astaxie", "1", timeoutDuration); err != nil { - t.Error("set Error", err) - } - - if v, err := strconv.Atoi(string(bm.Get("astaxie").([]byte))); err != nil || v != 1 { - t.Error("get err") - } - - if err = bm.Incr("astaxie"); err != nil { - t.Error("Incr Error", err) - } - - if v, err := strconv.Atoi(string(bm.Get("astaxie").([]byte))); err != nil || v != 2 { - t.Error("get err") - } - - if err = bm.Decr("astaxie"); err != nil { - t.Error("Decr Error", err) - } - - if v, err := strconv.Atoi(string(bm.Get("astaxie").([]byte))); err != nil || v != 1 { - t.Error("get err") - } - bm.Delete("astaxie") - if bm.IsExist("astaxie") { - t.Error("delete err") - } - - //test string - if err = bm.Put("astaxie", "author", timeoutDuration); err != nil { - t.Error("set Error", err) - } - if !bm.IsExist("astaxie") { - t.Error("check err") - } - - if v := bm.Get("astaxie").([]byte); string(v) != "author" { - t.Error("get err") - } - - //test GetMulti - if err = bm.Put("astaxie1", "author1", timeoutDuration); err != nil { - t.Error("set Error", err) - } - if !bm.IsExist("astaxie1") { - t.Error("check err") - } - - vv := bm.GetMulti([]string{"astaxie", "astaxie1"}) - if len(vv) != 2 { - t.Error("GetMulti ERROR") - } - if string(vv[0].([]byte)) != "author" && string(vv[0].([]byte)) != "author1" { - t.Error("GetMulti ERROR") - } - if string(vv[1].([]byte)) != "author1" && string(vv[1].([]byte)) != "author" { - t.Error("GetMulti ERROR") - } - - // test clear all - if err = bm.ClearAll(); err != nil { - t.Error("clear all err") - } -} diff --git a/cache/memory.go b/cache/memory.go deleted file mode 100644 index e5b562f003..0000000000 --- a/cache/memory.go +++ /dev/null @@ -1,253 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package cache - -import ( - "encoding/json" - "errors" - "sync" - "time" -) - -var ( - // DefaultEvery means the clock time of recycling the expired cache items in memory. - DefaultEvery = 60 // 1 minute -) - -// MemoryItem store memory cache item. -type MemoryItem struct { - val interface{} - createdTime time.Time - lifespan time.Duration -} - -func (mi *MemoryItem) isExpire() bool { - // 0 means forever - if mi.lifespan == 0 { - return false - } - return time.Now().Sub(mi.createdTime) > mi.lifespan -} - -// MemoryCache is Memory cache adapter. -// it contains a RW locker for safe map storage. -type MemoryCache struct { - sync.RWMutex - dur time.Duration - items map[string]*MemoryItem - Every int // run an expiration check Every clock time -} - -// NewMemoryCache returns a new MemoryCache. -func NewMemoryCache() Cache { - cache := MemoryCache{items: make(map[string]*MemoryItem)} - return &cache -} - -// Get cache from memory. -// if non-existed or expired, return nil. -func (bc *MemoryCache) Get(name string) interface{} { - bc.RLock() - defer bc.RUnlock() - if itm, ok := bc.items[name]; ok { - if itm.isExpire() { - return nil - } - return itm.val - } - return nil -} - -// GetMulti gets caches from memory. -// if non-existed or expired, return nil. -func (bc *MemoryCache) GetMulti(names []string) []interface{} { - var rc []interface{} - for _, name := range names { - rc = append(rc, bc.Get(name)) - } - return rc -} - -// Put cache to memory. -// if lifespan is 0, it will be forever till restart. -func (bc *MemoryCache) Put(name string, value interface{}, lifespan time.Duration) error { - bc.Lock() - defer bc.Unlock() - bc.items[name] = &MemoryItem{ - val: value, - createdTime: time.Now(), - lifespan: lifespan, - } - return nil -} - -// Delete cache in memory. -func (bc *MemoryCache) Delete(name string) error { - bc.Lock() - defer bc.Unlock() - if _, ok := bc.items[name]; !ok { - return errors.New("key not exist") - } - delete(bc.items, name) - if _, ok := bc.items[name]; ok { - return errors.New("delete key error") - } - return nil -} - -// Incr increase cache counter in memory. -// it supports int,int32,int64,uint,uint32,uint64. -func (bc *MemoryCache) Incr(key string) error { - bc.RLock() - defer bc.RUnlock() - itm, ok := bc.items[key] - if !ok { - return errors.New("key not exist") - } - switch val := itm.val.(type) { - case int: - itm.val = val + 1 - case int32: - itm.val = val + 1 - case int64: - itm.val = val + 1 - case uint: - itm.val = val + 1 - case uint32: - itm.val = val + 1 - case uint64: - itm.val = val + 1 - default: - return errors.New("item val is not (u)int (u)int32 (u)int64") - } - return nil -} - -// Decr decrease counter in memory. -func (bc *MemoryCache) Decr(key string) error { - bc.RLock() - defer bc.RUnlock() - itm, ok := bc.items[key] - if !ok { - return errors.New("key not exist") - } - switch val := itm.val.(type) { - case int: - itm.val = val - 1 - case int64: - itm.val = val - 1 - case int32: - itm.val = val - 1 - case uint: - if val > 0 { - itm.val = val - 1 - } else { - return errors.New("item val is less than 0") - } - case uint32: - if val > 0 { - itm.val = val - 1 - } else { - return errors.New("item val is less than 0") - } - case uint64: - if val > 0 { - itm.val = val - 1 - } else { - return errors.New("item val is less than 0") - } - default: - return errors.New("item val is not int int64 int32") - } - return nil -} - -// IsExist check cache exist in memory. -func (bc *MemoryCache) IsExist(name string) bool { - bc.RLock() - defer bc.RUnlock() - if v, ok := bc.items[name]; ok { - return !v.isExpire() - } - return false -} - -// ClearAll will delete all cache in memory. -func (bc *MemoryCache) ClearAll() error { - bc.Lock() - defer bc.Unlock() - bc.items = make(map[string]*MemoryItem) - return nil -} - -// StartAndGC start memory cache. it will check expiration in every clock time. -func (bc *MemoryCache) StartAndGC(config string) error { - var cf map[string]int - json.Unmarshal([]byte(config), &cf) - if _, ok := cf["interval"]; !ok { - cf = make(map[string]int) - cf["interval"] = DefaultEvery - } - dur := time.Duration(cf["interval"]) * time.Second - bc.Every = cf["interval"] - bc.dur = dur - go bc.vacuum() - return nil -} - -// check expiration. -func (bc *MemoryCache) vacuum() { - bc.RLock() - every := bc.Every - bc.RUnlock() - - if every < 1 { - return - } - for { - <-time.After(bc.dur) - if bc.items == nil { - return - } - if keys := bc.expiredKeys(); len(keys) != 0 { - bc.clearItems(keys) - } - } -} - -// expiredKeys returns key list which are expired. -func (bc *MemoryCache) expiredKeys() (keys []string) { - bc.RLock() - defer bc.RUnlock() - for key, itm := range bc.items { - if itm.isExpire() { - keys = append(keys, key) - } - } - return -} - -// clearItems removes all the items which key in keys. -func (bc *MemoryCache) clearItems(keys []string) { - bc.Lock() - defer bc.Unlock() - for _, key := range keys { - delete(bc.items, key) - } -} - -func init() { - Register("memory", NewMemoryCache) -} diff --git a/cache/redis/redis.go b/cache/redis/redis.go deleted file mode 100644 index 372dd48b3b..0000000000 --- a/cache/redis/redis.go +++ /dev/null @@ -1,231 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package redis for cache provider -// -// depend on github.com/gomodule/redigo/redis -// -// go install github.com/gomodule/redigo/redis -// -// Usage: -// import( -// _ "github.com/astaxie/beego/cache/redis" -// "github.com/astaxie/beego/cache" -// ) -// -// bm, err := cache.NewCache("redis", `{"conn":"127.0.0.1:11211"}`) -// -// more docs http://beego.me/docs/module/cache.md -package redis - -import ( - "encoding/json" - "errors" - "fmt" - "strconv" - "time" - - "github.com/gomodule/redigo/redis" - - "github.com/astaxie/beego/cache" - "strings" -) - -var ( - // DefaultKey the collection name of redis for cache adapter. - DefaultKey = "beecacheRedis" -) - -// Cache is Redis cache adapter. -type Cache struct { - p *redis.Pool // redis connection pool - conninfo string - dbNum int - key string - password string - maxIdle int -} - -// NewRedisCache create new redis cache with default collection name. -func NewRedisCache() cache.Cache { - return &Cache{key: DefaultKey} -} - -// actually do the redis cmds, args[0] must be the key name. -func (rc *Cache) do(commandName string, args ...interface{}) (reply interface{}, err error) { - if len(args) < 1 { - return nil, errors.New("missing required arguments") - } - args[0] = rc.associate(args[0]) - c := rc.p.Get() - defer c.Close() - - return c.Do(commandName, args...) -} - -// associate with config key. -func (rc *Cache) associate(originKey interface{}) string { - return fmt.Sprintf("%s:%s", rc.key, originKey) -} - -// Get cache from redis. -func (rc *Cache) Get(key string) interface{} { - if v, err := rc.do("GET", key); err == nil { - return v - } - return nil -} - -// GetMulti get cache from redis. -func (rc *Cache) GetMulti(keys []string) []interface{} { - c := rc.p.Get() - defer c.Close() - var args []interface{} - for _, key := range keys { - args = append(args, rc.associate(key)) - } - values, err := redis.Values(c.Do("MGET", args...)) - if err != nil { - return nil - } - return values -} - -// Put put cache to redis. -func (rc *Cache) Put(key string, val interface{}, timeout time.Duration) error { - _, err := rc.do("SETEX", key, int64(timeout/time.Second), val) - return err -} - -// Delete delete cache in redis. -func (rc *Cache) Delete(key string) error { - _, err := rc.do("DEL", key) - return err -} - -// IsExist check cache's existence in redis. -func (rc *Cache) IsExist(key string) bool { - v, err := redis.Bool(rc.do("EXISTS", key)) - if err != nil { - return false - } - return v -} - -// Incr increase counter in redis. -func (rc *Cache) Incr(key string) error { - _, err := redis.Bool(rc.do("INCRBY", key, 1)) - return err -} - -// Decr decrease counter in redis. -func (rc *Cache) Decr(key string) error { - _, err := redis.Bool(rc.do("INCRBY", key, -1)) - return err -} - -// ClearAll clean all cache in redis. delete this redis collection. -func (rc *Cache) ClearAll() error { - c := rc.p.Get() - defer c.Close() - cachedKeys, err := redis.Strings(c.Do("KEYS", rc.key+":*")) - if err != nil { - return err - } - for _, str := range cachedKeys { - if _, err = c.Do("DEL", str); err != nil { - return err - } - } - return err -} - -// StartAndGC start redis cache adapter. -// config is like {"key":"collection key","conn":"connection info","dbNum":"0"} -// the cache item in redis are stored forever, -// so no gc operation. -func (rc *Cache) StartAndGC(config string) error { - var cf map[string]string - json.Unmarshal([]byte(config), &cf) - - if _, ok := cf["key"]; !ok { - cf["key"] = DefaultKey - } - if _, ok := cf["conn"]; !ok { - return errors.New("config has no conn key") - } - - // Format redis://@: - cf["conn"] = strings.Replace(cf["conn"], "redis://", "", 1) - if i := strings.Index(cf["conn"], "@"); i > -1 { - cf["password"] = cf["conn"][0:i] - cf["conn"] = cf["conn"][i+1:] - } - - if _, ok := cf["dbNum"]; !ok { - cf["dbNum"] = "0" - } - if _, ok := cf["password"]; !ok { - cf["password"] = "" - } - if _, ok := cf["maxIdle"]; !ok { - cf["maxIdle"] = "3" - } - rc.key = cf["key"] - rc.conninfo = cf["conn"] - rc.dbNum, _ = strconv.Atoi(cf["dbNum"]) - rc.password = cf["password"] - rc.maxIdle, _ = strconv.Atoi(cf["maxIdle"]) - - rc.connectInit() - - c := rc.p.Get() - defer c.Close() - - return c.Err() -} - -// connect to redis. -func (rc *Cache) connectInit() { - dialFunc := func() (c redis.Conn, err error) { - c, err = redis.Dial("tcp", rc.conninfo) - if err != nil { - return nil, err - } - - if rc.password != "" { - if _, err := c.Do("AUTH", rc.password); err != nil { - c.Close() - return nil, err - } - } - - _, selecterr := c.Do("SELECT", rc.dbNum) - if selecterr != nil { - c.Close() - return nil, selecterr - } - return - } - // initialize a new pool - rc.p = &redis.Pool{ - MaxIdle: rc.maxIdle, - IdleTimeout: 180 * time.Second, - Dial: dialFunc, - } -} - -func init() { - cache.Register("redis", NewRedisCache) -} diff --git a/cache/redis/redis_test.go b/cache/redis/redis_test.go deleted file mode 100644 index 56877f6bf2..0000000000 --- a/cache/redis/redis_test.go +++ /dev/null @@ -1,106 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package redis - -import ( - "testing" - "time" - - "github.com/astaxie/beego/cache" - "github.com/gomodule/redigo/redis" -) - -func TestRedisCache(t *testing.T) { - bm, err := cache.NewCache("redis", `{"conn": "127.0.0.1:6379"}`) - if err != nil { - t.Error("init err") - } - timeoutDuration := 10 * time.Second - if err = bm.Put("astaxie", 1, timeoutDuration); err != nil { - t.Error("set Error", err) - } - if !bm.IsExist("astaxie") { - t.Error("check err") - } - - time.Sleep(11 * time.Second) - - if bm.IsExist("astaxie") { - t.Error("check err") - } - if err = bm.Put("astaxie", 1, timeoutDuration); err != nil { - t.Error("set Error", err) - } - - if v, _ := redis.Int(bm.Get("astaxie"), err); v != 1 { - t.Error("get err") - } - - if err = bm.Incr("astaxie"); err != nil { - t.Error("Incr Error", err) - } - - if v, _ := redis.Int(bm.Get("astaxie"), err); v != 2 { - t.Error("get err") - } - - if err = bm.Decr("astaxie"); err != nil { - t.Error("Decr Error", err) - } - - if v, _ := redis.Int(bm.Get("astaxie"), err); v != 1 { - t.Error("get err") - } - bm.Delete("astaxie") - if bm.IsExist("astaxie") { - t.Error("delete err") - } - - //test string - if err = bm.Put("astaxie", "author", timeoutDuration); err != nil { - t.Error("set Error", err) - } - if !bm.IsExist("astaxie") { - t.Error("check err") - } - - if v, _ := redis.String(bm.Get("astaxie"), err); v != "author" { - t.Error("get err") - } - - //test GetMulti - if err = bm.Put("astaxie1", "author1", timeoutDuration); err != nil { - t.Error("set Error", err) - } - if !bm.IsExist("astaxie1") { - t.Error("check err") - } - - vv := bm.GetMulti([]string{"astaxie", "astaxie1"}) - if len(vv) != 2 { - t.Error("GetMulti ERROR") - } - if v, _ := redis.String(vv[0], nil); v != "author" { - t.Error("GetMulti ERROR") - } - if v, _ := redis.String(vv[1], nil); v != "author1" { - t.Error("GetMulti ERROR") - } - - // test clear all - if err = bm.ClearAll(); err != nil { - t.Error("clear all err") - } -} diff --git a/cache/ssdb/ssdb.go b/cache/ssdb/ssdb.go deleted file mode 100644 index fa2ce04bb6..0000000000 --- a/cache/ssdb/ssdb.go +++ /dev/null @@ -1,231 +0,0 @@ -package ssdb - -import ( - "encoding/json" - "errors" - "strconv" - "strings" - "time" - - "github.com/ssdb/gossdb/ssdb" - - "github.com/astaxie/beego/cache" -) - -// Cache SSDB adapter -type Cache struct { - conn *ssdb.Client - conninfo []string -} - -//NewSsdbCache create new ssdb adapter. -func NewSsdbCache() cache.Cache { - return &Cache{} -} - -// Get get value from memcache. -func (rc *Cache) Get(key string) interface{} { - if rc.conn == nil { - if err := rc.connectInit(); err != nil { - return nil - } - } - value, err := rc.conn.Get(key) - if err == nil { - return value - } - return nil -} - -// GetMulti get value from memcache. -func (rc *Cache) GetMulti(keys []string) []interface{} { - size := len(keys) - var values []interface{} - if rc.conn == nil { - if err := rc.connectInit(); err != nil { - for i := 0; i < size; i++ { - values = append(values, err) - } - return values - } - } - res, err := rc.conn.Do("multi_get", keys) - resSize := len(res) - if err == nil { - for i := 1; i < resSize; i += 2 { - values = append(values, res[i+1]) - } - return values - } - for i := 0; i < size; i++ { - values = append(values, err) - } - return values -} - -// DelMulti get value from memcache. -func (rc *Cache) DelMulti(keys []string) error { - if rc.conn == nil { - if err := rc.connectInit(); err != nil { - return err - } - } - _, err := rc.conn.Do("multi_del", keys) - return err -} - -// Put put value to memcache. only support string. -func (rc *Cache) Put(key string, value interface{}, timeout time.Duration) error { - if rc.conn == nil { - if err := rc.connectInit(); err != nil { - return err - } - } - v, ok := value.(string) - if !ok { - return errors.New("value must string") - } - var resp []string - var err error - ttl := int(timeout / time.Second) - if ttl < 0 { - resp, err = rc.conn.Do("set", key, v) - } else { - resp, err = rc.conn.Do("setx", key, v, ttl) - } - if err != nil { - return err - } - if len(resp) == 2 && resp[0] == "ok" { - return nil - } - return errors.New("bad response") -} - -// Delete delete value in memcache. -func (rc *Cache) Delete(key string) error { - if rc.conn == nil { - if err := rc.connectInit(); err != nil { - return err - } - } - _, err := rc.conn.Del(key) - return err -} - -// Incr increase counter. -func (rc *Cache) Incr(key string) error { - if rc.conn == nil { - if err := rc.connectInit(); err != nil { - return err - } - } - _, err := rc.conn.Do("incr", key, 1) - return err -} - -// Decr decrease counter. -func (rc *Cache) Decr(key string) error { - if rc.conn == nil { - if err := rc.connectInit(); err != nil { - return err - } - } - _, err := rc.conn.Do("incr", key, -1) - return err -} - -// IsExist check value exists in memcache. -func (rc *Cache) IsExist(key string) bool { - if rc.conn == nil { - if err := rc.connectInit(); err != nil { - return false - } - } - resp, err := rc.conn.Do("exists", key) - if err != nil { - return false - } - if len(resp) == 2 && resp[1] == "1" { - return true - } - return false - -} - -// ClearAll clear all cached in memcache. -func (rc *Cache) ClearAll() error { - if rc.conn == nil { - if err := rc.connectInit(); err != nil { - return err - } - } - keyStart, keyEnd, limit := "", "", 50 - resp, err := rc.Scan(keyStart, keyEnd, limit) - for err == nil { - size := len(resp) - if size == 1 { - return nil - } - keys := []string{} - for i := 1; i < size; i += 2 { - keys = append(keys, resp[i]) - } - _, e := rc.conn.Do("multi_del", keys) - if e != nil { - return e - } - keyStart = resp[size-2] - resp, err = rc.Scan(keyStart, keyEnd, limit) - } - return err -} - -// Scan key all cached in ssdb. -func (rc *Cache) Scan(keyStart string, keyEnd string, limit int) ([]string, error) { - if rc.conn == nil { - if err := rc.connectInit(); err != nil { - return nil, err - } - } - resp, err := rc.conn.Do("scan", keyStart, keyEnd, limit) - if err != nil { - return nil, err - } - return resp, nil -} - -// StartAndGC start memcache adapter. -// config string is like {"conn":"connection info"}. -// if connecting error, return. -func (rc *Cache) StartAndGC(config string) error { - var cf map[string]string - json.Unmarshal([]byte(config), &cf) - if _, ok := cf["conn"]; !ok { - return errors.New("config has no conn key") - } - rc.conninfo = strings.Split(cf["conn"], ";") - if rc.conn == nil { - if err := rc.connectInit(); err != nil { - return err - } - } - return nil -} - -// connect to memcache and keep the connection. -func (rc *Cache) connectInit() error { - conninfoArray := strings.Split(rc.conninfo[0], ":") - host := conninfoArray[0] - port, e := strconv.Atoi(conninfoArray[1]) - if e != nil { - return e - } - var err error - rc.conn, err = ssdb.Connect(host, port) - return err -} - -func init() { - cache.Register("ssdb", NewSsdbCache) -} diff --git a/cache/ssdb/ssdb_test.go b/cache/ssdb/ssdb_test.go deleted file mode 100644 index dd474960aa..0000000000 --- a/cache/ssdb/ssdb_test.go +++ /dev/null @@ -1,104 +0,0 @@ -package ssdb - -import ( - "strconv" - "testing" - "time" - - "github.com/astaxie/beego/cache" -) - -func TestSsdbcacheCache(t *testing.T) { - ssdb, err := cache.NewCache("ssdb", `{"conn": "127.0.0.1:8888"}`) - if err != nil { - t.Error("init err") - } - - // test put and exist - if ssdb.IsExist("ssdb") { - t.Error("check err") - } - timeoutDuration := 10 * time.Second - //timeoutDuration := -10*time.Second if timeoutDuration is negtive,it means permanent - if err = ssdb.Put("ssdb", "ssdb", timeoutDuration); err != nil { - t.Error("set Error", err) - } - if !ssdb.IsExist("ssdb") { - t.Error("check err") - } - - // Get test done - if err = ssdb.Put("ssdb", "ssdb", timeoutDuration); err != nil { - t.Error("set Error", err) - } - - if v := ssdb.Get("ssdb"); v != "ssdb" { - t.Error("get Error") - } - - //inc/dec test done - if err = ssdb.Put("ssdb", "2", timeoutDuration); err != nil { - t.Error("set Error", err) - } - if err = ssdb.Incr("ssdb"); err != nil { - t.Error("incr Error", err) - } - - if v, err := strconv.Atoi(ssdb.Get("ssdb").(string)); err != nil || v != 3 { - t.Error("get err") - } - - if err = ssdb.Decr("ssdb"); err != nil { - t.Error("decr error") - } - - // test del - if err = ssdb.Put("ssdb", "3", timeoutDuration); err != nil { - t.Error("set Error", err) - } - if v, err := strconv.Atoi(ssdb.Get("ssdb").(string)); err != nil || v != 3 { - t.Error("get err") - } - if err := ssdb.Delete("ssdb"); err == nil { - if ssdb.IsExist("ssdb") { - t.Error("delete err") - } - } - - //test string - if err = ssdb.Put("ssdb", "ssdb", -10*time.Second); err != nil { - t.Error("set Error", err) - } - if !ssdb.IsExist("ssdb") { - t.Error("check err") - } - if v := ssdb.Get("ssdb").(string); v != "ssdb" { - t.Error("get err") - } - - //test GetMulti done - if err = ssdb.Put("ssdb1", "ssdb1", -10*time.Second); err != nil { - t.Error("set Error", err) - } - if !ssdb.IsExist("ssdb1") { - t.Error("check err") - } - vv := ssdb.GetMulti([]string{"ssdb", "ssdb1"}) - if len(vv) != 2 { - t.Error("getmulti error") - } - if vv[0].(string) != "ssdb" { - t.Error("getmulti error") - } - if vv[1].(string) != "ssdb1" { - t.Error("getmulti error") - } - - // test clear all done - if err = ssdb.ClearAll(); err != nil { - t.Error("clear all err") - } - if ssdb.IsExist("ssdb") || ssdb.IsExist("ssdb1") { - t.Error("check err") - } -} diff --git a/client/cache/bloom_filter_cache.go b/client/cache/bloom_filter_cache.go new file mode 100644 index 0000000000..385ae63c80 --- /dev/null +++ b/client/cache/bloom_filter_cache.go @@ -0,0 +1,71 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cache + +import ( + "context" + "errors" + "time" + + "github.com/beego/beego/v2/core/berror" +) + +type BloomFilterCache struct { + Cache + blm BloomFilter + loadFunc func(ctx context.Context, key string) (any, error) + expiration time.Duration // set cache expiration, default never expire +} + +type BloomFilter interface { + Test(data string) bool + Add(data string) +} + +func NewBloomFilterCache(cache Cache, ln func(ctx context.Context, key string) (any, error), blm BloomFilter, + expiration time.Duration, +) (*BloomFilterCache, error) { + if cache == nil || ln == nil || blm == nil { + return nil, berror.Error(InvalidInitParameters, "missing required parameters") + } + + return &BloomFilterCache{ + Cache: cache, + blm: blm, + loadFunc: ln, + expiration: expiration, + }, nil +} + +func (bfc *BloomFilterCache) Get(ctx context.Context, key string) (any, error) { + val, err := bfc.Cache.Get(ctx, key) + if err != nil && !errors.Is(err, ErrKeyNotExist) { + return nil, err + } + if errors.Is(err, ErrKeyNotExist) { + exist := bfc.blm.Test(key) + if exist { + val, err = bfc.loadFunc(ctx, key) + if err != nil { + return nil, berror.Wrap(err, LoadFuncFailed, "cache unable to load data") + } + err = bfc.Put(ctx, key, val, bfc.expiration) + if err != nil { + return val, err + } + } + } + return val, nil +} diff --git a/client/cache/bloom_filter_cache_test.go b/client/cache/bloom_filter_cache_test.go new file mode 100644 index 0000000000..d26d27cce7 --- /dev/null +++ b/client/cache/bloom_filter_cache_test.go @@ -0,0 +1,207 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// nolint +package cache + +import ( + "context" + "errors" + "fmt" + "sync" + "testing" + "time" + + "github.com/bits-and-blooms/bloom/v3" + "github.com/stretchr/testify/assert" + + "github.com/beego/beego/v2/core/berror" +) + +type MockDB struct { + Db Cache + loadCnt int64 +} + +type BloomFilterMock struct { + *bloom.BloomFilter + lock *sync.RWMutex + concurrent bool +} + +func (b *BloomFilterMock) Add(data string) { + if b.concurrent { + b.lock.Lock() + defer b.lock.Unlock() + } + b.BloomFilter.AddString(data) +} + +func (b *BloomFilterMock) Test(data string) bool { + if b.concurrent { + b.lock.Lock() + defer b.lock.Unlock() + } + return b.BloomFilter.TestString(data) +} + +var ( + mockDB = MockDB{Db: NewMemoryCache(), loadCnt: 0} + mockBloom = &BloomFilterMock{ + BloomFilter: bloom.NewWithEstimates(20000, 0.01), + lock: &sync.RWMutex{}, + concurrent: false, + } + loadFunc = func(ctx context.Context, key string) (any, error) { + mockDB.loadCnt += 1 // flag of number load data from db + v, err := mockDB.Db.Get(context.Background(), key) + if err != nil { + return nil, errors.New("fail") + } + return v, nil + } + cacheUnderlying = NewMemoryCache() +) + +func TestBloomFilterCache_Get(t *testing.T) { + testCases := []struct { + name string + key string + wantVal any + + before func() + after func() + + wantErrCode uint32 + }{ + // case: keys exist in cache + // want: not load data from db + { + name: "not_load_db", + before: func() { + _ = cacheUnderlying.Put(context.Background(), "exist_in_cache", "123", time.Minute) + }, + key: "exist_in_DB", + after: func() { + assert.Equal(t, mockDB.loadCnt, int64(0)) + _ = cacheUnderlying.Delete(context.Background(), "exist_in_cache") + mockDB.loadCnt = 0 + _ = mockDB.Db.ClearAll(context.Background()) + }, + }, + // case: keys not exist in cache, not exist in bloom + // want: not load data from db + { + name: "not_load_db", + before: func() { + _ = mockDB.Db.ClearAll(context.Background()) + _ = mockDB.Db.Put(context.Background(), "exist_in_DB", "exist_in_DB", 0) + mockBloom.AddString("other") + }, + key: "exist_in_DB", + after: func() { + assert.Equal(t, mockDB.loadCnt, int64(0)) + mockBloom.ClearAll() + mockDB.loadCnt = 0 + _ = mockDB.Db.ClearAll(context.Background()) + }, + }, + // case: keys not exist in cache, exist in bloom, exist in db, + // want: load data from db, and set cache + { + name: "load_db", + before: func() { + _ = mockDB.Db.ClearAll(context.Background()) + _ = mockDB.Db.Put(context.Background(), "exist_in_DB", "exist_in_DB", 0) + mockBloom.Add("exist_in_DB") + }, + key: "exist_in_DB", + wantVal: "exist_in_DB", + after: func() { + assert.Equal(t, mockDB.loadCnt, int64(1)) + _ = cacheUnderlying.Delete(context.Background(), "exist_in_DB") + mockBloom.ClearAll() + mockDB.loadCnt = 0 + _ = mockDB.Db.ClearAll(context.Background()) + }, + }, + // case: keys not exist in cache, exist in bloom, not exist in db, + // want: load func error + { + name: "load db fail", + before: func() { + mockBloom.Add("not_exist_in_DB") + }, + after: func() { + assert.Equal(t, mockDB.loadCnt, int64(1)) + mockBloom.ClearAll() + mockDB.loadCnt = 0 + _ = mockDB.Db.ClearAll(context.Background()) + }, + key: "not_exist_in_DB", + wantErrCode: LoadFuncFailed.Code(), + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + tc.before() + bfc, err := NewBloomFilterCache(cacheUnderlying, loadFunc, mockBloom, time.Minute) + assert.Nil(t, err) + + got, err := bfc.Get(context.Background(), tc.key) + if tc.wantErrCode != 0 { + errCode, _ := berror.FromError(err) + assert.Equal(t, tc.wantErrCode, errCode.Code()) + return + } else { + assert.Nil(t, err) + } + assert.Equal(t, tc.wantVal, got) + + cacheVal, _ := bfc.Cache.Get(context.Background(), tc.key) + assert.Equal(t, tc.wantVal, cacheVal) + tc.after() + }) + } +} + +func ExampleNewBloomFilterCache() { + c := NewMemoryCache() + c, err := NewBloomFilterCache(c, func(ctx context.Context, key string) (any, error) { + return fmt.Sprintf("hello, %s", key), nil + }, &AlwaysExist{}, time.Minute) + if err != nil { + panic(err) + } + + val, err := c.Get(context.Background(), "Beego") + if err != nil { + panic(err) + } + fmt.Println(val) + // Output: + // hello, Beego +} + +type AlwaysExist struct { +} + +func (*AlwaysExist) Test(string) bool { + return true +} + +func (*AlwaysExist) Add(string) { + +} diff --git a/cache/cache.go b/client/cache/cache.go similarity index 61% rename from cache/cache.go rename to client/cache/cache.go index 82585c4e55..f7599dbb99 100644 --- a/cache/cache.go +++ b/client/cache/cache.go @@ -16,7 +16,9 @@ // Usage: // // import( -// "github.com/astaxie/beego/cache" +// +// "github.com/beego/beego/v2/client/cache" +// // ) // // bm, err := cache.NewCache("memory", `{"interval":60}`) @@ -27,17 +29,18 @@ // bm.Get("astaxie") // bm.IsExist("astaxie") // bm.Delete("astaxie") -// -// more docs http://beego.me/docs/module/cache.md package cache import ( - "fmt" + "context" "time" + + "github.com/beego/beego/v2/core/berror" ) // Cache interface contains all behaviors for cache adapter. // usage: +// // cache.Register("file",cache.NewFileCache) // this operation is run in init method of file.go. // c,err := cache.NewCache("file","{....}") // c.Put("key",value, 3600 * time.Second) @@ -47,23 +50,25 @@ import ( // c.Incr("counter") // now is 2 // count := c.Get("counter").(int) type Cache interface { - // get cached value by key. - Get(key string) interface{} + // Get a cached value by key. + Get(ctx context.Context, key string) (interface{}, error) // GetMulti is a batch version of Get. - GetMulti(keys []string) []interface{} - // set cached value with key and expire time. - Put(key string, val interface{}, timeout time.Duration) error - // delete cached value by key. - Delete(key string) error - // increase cached int value by key, as a counter. - Incr(key string) error - // decrease cached int value by key, as a counter. - Decr(key string) error - // check if cached value exists or not. - IsExist(key string) bool - // clear all cache. - ClearAll() error - // start gc routine based on config string settings. + GetMulti(ctx context.Context, keys []string) ([]interface{}, error) + // Put Set a cached value with key and expire time. + Put(ctx context.Context, key string, val interface{}, timeout time.Duration) error + // Delete cached value by key. + // Should not return error if key not found + Delete(ctx context.Context, key string) error + // Incr Increment a cached int value by key, as a counter. + Incr(ctx context.Context, key string) error + // Decr Decrement a cached int value by key, as a counter. + Decr(ctx context.Context, key string) error + // IsExist Check if a cached value exists or not. + // if key is expired, return (false, nil) + IsExist(ctx context.Context, key string) (bool, error) + // ClearAll Clear all cache. + ClearAll(ctx context.Context) error + // StartAndGC Start gc routine based on config string settings. StartAndGC(config string) error } @@ -77,7 +82,7 @@ var adapters = make(map[string]Instance) // it panics. func Register(name string, adapter Instance) { if adapter == nil { - panic("cache: Register adapter is nil") + panic(berror.Error(NilCacheAdapter, "cache: Register adapter is nil").Error()) } if _, ok := adapters[name]; ok { panic("cache: Register called twice for adapter " + name) @@ -85,13 +90,13 @@ func Register(name string, adapter Instance) { adapters[name] = adapter } -// NewCache Create a new cache driver by adapter name and config string. -// config need to be correct JSON as string: {"interval":360}. -// it will start gc automatically. +// NewCache creates a new cache driver by adapter name and config string. +// config: must be in JSON format such as {"interval":360}. +// Starts gc automatically. func NewCache(adapterName, config string) (adapter Cache, err error) { instanceFunc, ok := adapters[adapterName] if !ok { - err = fmt.Errorf("cache: unknown adapter name %q (forgot to import?)", adapterName) + err = berror.Errorf(UnknownAdapter, "cache: unknown adapter name %s (forgot to import?)", adapterName) return } adapter = instanceFunc() diff --git a/client/cache/cache_test.go b/client/cache/cache_test.go new file mode 100644 index 0000000000..c4a6dec8ce --- /dev/null +++ b/client/cache/cache_test.go @@ -0,0 +1,222 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cache + +import ( + "context" + "math" + "os" + "strings" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestCacheIncr(t *testing.T) { + bm, err := NewCache("memory", `{"interval":20}`) + assert.Nil(t, err) + // timeoutDuration := 10 * time.Second + + err = bm.Put(context.Background(), "edwardhey", 0, time.Second*20) + assert.Nil(t, err) + wg := sync.WaitGroup{} + wg.Add(10) + for i := 0; i < 10; i++ { + go func() { + defer wg.Done() + _ = bm.Incr(context.Background(), "edwardhey") + }() + } + wg.Wait() + val, _ := bm.Get(context.Background(), "edwardhey") + if val.(int) != 10 { + t.Error("Incr err") + } +} + +func TestCache(t *testing.T) { + bm, err := NewCache("memory", `{"interval":1}`) + assert.Nil(t, err) + timeoutDuration := 5 * time.Second + if err = bm.Put(context.Background(), "astaxie", 1, timeoutDuration); err != nil { + t.Error("set Error", err) + } + if res, _ := bm.IsExist(context.Background(), "astaxie"); !res { + t.Error("check err") + } + + if v, _ := bm.Get(context.Background(), "astaxie"); v.(int) != 1 { + t.Error("get err") + } + + time.Sleep(7 * time.Second) + + if res, _ := bm.IsExist(context.Background(), "astaxie"); res { + t.Error("check err") + } + + if err = bm.Put(context.Background(), "astaxie", 1, timeoutDuration); err != nil { + t.Error("set Error", err) + } + + // test different integer type for incr & decr + testMultiTypeIncrDecr(t, bm, timeoutDuration) + + // test overflow of incr&decr + testIncrOverFlow(t, bm, timeoutDuration) + testDecrOverFlow(t, bm, timeoutDuration) + + assert.Nil(t, bm.Delete(context.Background(), "astaxie")) + res, _ := bm.IsExist(context.Background(), "astaxie") + assert.False(t, res) + + assert.Nil(t, bm.Put(context.Background(), "astaxie", "author", timeoutDuration)) + + res, _ = bm.IsExist(context.Background(), "astaxie") + assert.True(t, res) + + v, _ := bm.Get(context.Background(), "astaxie") + assert.Equal(t, "author", v) + + assert.Nil(t, bm.Put(context.Background(), "astaxie1", "author1", timeoutDuration)) + + res, _ = bm.IsExist(context.Background(), "astaxie1") + assert.True(t, res) + + vv, _ := bm.GetMulti(context.Background(), []string{"astaxie", "astaxie1"}) + assert.Equal(t, 2, len(vv)) + assert.Equal(t, "author", vv[0]) + assert.Equal(t, "author1", vv[1]) + + vv, err = bm.GetMulti(context.Background(), []string{"astaxie0", "astaxie1"}) + assert.Equal(t, 2, len(vv)) + assert.Nil(t, vv[0]) + assert.Equal(t, "author1", vv[1]) + + assert.NotNil(t, err) + assert.True(t, strings.Contains(err.Error(), "key isn't exist")) +} + +func TestFileCache(t *testing.T) { + bm, err := NewCache("file", `{"CachePath":"cache","FileSuffix":".bin","DirectoryLevel":"2","EmbedExpiry":"0"}`) + assert.Nil(t, err) + timeoutDuration := 10 * time.Second + assert.Nil(t, bm.Put(context.Background(), "astaxie", 1, timeoutDuration)) + + res, _ := bm.IsExist(context.Background(), "astaxie") + assert.True(t, res) + v, _ := bm.Get(context.Background(), "astaxie") + assert.Equal(t, 1, v) + + // test different integer type for incr & decr + testMultiTypeIncrDecr(t, bm, timeoutDuration) + + // test overflow of incr&decr + testIncrOverFlow(t, bm, timeoutDuration) + testDecrOverFlow(t, bm, timeoutDuration) + + assert.Nil(t, bm.Delete(context.Background(), "astaxie")) + res, _ = bm.IsExist(context.Background(), "astaxie") + assert.False(t, res) + + // test string + assert.Nil(t, bm.Put(context.Background(), "astaxie", "author", timeoutDuration)) + res, _ = bm.IsExist(context.Background(), "astaxie") + assert.True(t, res) + + v, _ = bm.Get(context.Background(), "astaxie") + assert.Equal(t, "author", v) + + // test GetMulti + assert.Nil(t, bm.Put(context.Background(), "astaxie1", "author1", timeoutDuration)) + + res, _ = bm.IsExist(context.Background(), "astaxie1") + assert.True(t, res) + + vv, _ := bm.GetMulti(context.Background(), []string{"astaxie", "astaxie1"}) + assert.Equal(t, 2, len(vv)) + assert.Equal(t, "author", vv[0]) + assert.Equal(t, "author1", vv[1]) + + vv, err = bm.GetMulti(context.Background(), []string{"astaxie0", "astaxie1"}) + assert.Equal(t, 2, len(vv)) + + assert.Nil(t, vv[0]) + + assert.Equal(t, "author1", vv[1]) + assert.NotNil(t, err) + assert.Nil(t, os.RemoveAll("cache")) +} + +func testMultiTypeIncrDecr(t *testing.T, c Cache, timeout time.Duration) { + testIncrDecr(t, c, 1, 2, timeout) + testIncrDecr(t, c, int32(1), int32(2), timeout) + testIncrDecr(t, c, int64(1), int64(2), timeout) + testIncrDecr(t, c, uint(1), uint(2), timeout) + testIncrDecr(t, c, uint32(1), uint32(2), timeout) + testIncrDecr(t, c, uint64(1), uint64(2), timeout) +} + +func testIncrDecr(t *testing.T, c Cache, beforeIncr interface{}, afterIncr interface{}, timeout time.Duration) { + ctx := context.Background() + key := "incDecKey" + + assert.Nil(t, c.Put(ctx, key, beforeIncr, timeout)) + assert.Nil(t, c.Incr(ctx, key)) + + v, _ := c.Get(ctx, key) + assert.Equal(t, afterIncr, v) + + assert.Nil(t, c.Decr(ctx, key)) + + v, _ = c.Get(ctx, key) + assert.Equal(t, v, beforeIncr) + assert.Nil(t, c.Delete(ctx, key)) +} + +func testIncrOverFlow(t *testing.T, c Cache, timeout time.Duration) { + ctx := context.Background() + key := "incKey" + + assert.Nil(t, c.Put(ctx, key, int64(math.MaxInt64), timeout)) + // int64 + defer func() { + assert.Nil(t, c.Delete(ctx, key)) + }() + assert.NotNil(t, c.Incr(ctx, key)) +} + +func testDecrOverFlow(t *testing.T, c Cache, timeout time.Duration) { + var err error + ctx := context.Background() + key := "decKey" + + // int64 + if err = c.Put(ctx, key, int64(math.MinInt64), timeout); err != nil { + t.Error("Put Error: ", err.Error()) + return + } + defer func() { + if err = c.Delete(ctx, key); err != nil { + t.Errorf("Delete error: %s", err.Error()) + } + }() + if err = c.Decr(ctx, key); err == nil { + t.Error("Decr error") + return + } +} diff --git a/client/cache/calc_utils.go b/client/cache/calc_utils.go new file mode 100644 index 0000000000..f8b7f24ace --- /dev/null +++ b/client/cache/calc_utils.go @@ -0,0 +1,95 @@ +package cache + +import ( + "math" + + "github.com/beego/beego/v2/core/berror" +) + +var ( + ErrIncrementOverflow = berror.Error(IncrementOverflow, "this incr invocation will overflow.") + ErrDecrementOverflow = berror.Error(DecrementOverflow, "this decr invocation will overflow.") + ErrNotIntegerType = berror.Error(NotIntegerType, "item val is not (u)int (u)int32 (u)int64") +) + +const ( + MinUint32 uint32 = 0 + MinUint64 uint64 = 0 +) + +func incr(originVal interface{}) (interface{}, error) { + switch val := originVal.(type) { + case int: + tmp := val + 1 + if val > 0 && tmp < 0 { + return nil, ErrIncrementOverflow + } + return tmp, nil + case int32: + if val == math.MaxInt32 { + return nil, ErrIncrementOverflow + } + return val + 1, nil + case int64: + if val == math.MaxInt64 { + return nil, ErrIncrementOverflow + } + return val + 1, nil + case uint: + tmp := val + 1 + if tmp < val { + return nil, ErrIncrementOverflow + } + return tmp, nil + case uint32: + if val == math.MaxUint32 { + return nil, ErrIncrementOverflow + } + return val + 1, nil + case uint64: + if val == math.MaxUint64 { + return nil, ErrIncrementOverflow + } + return val + 1, nil + default: + return nil, ErrNotIntegerType + } +} + +func decr(originVal interface{}) (interface{}, error) { + switch val := originVal.(type) { + case int: + tmp := val - 1 + if val < 0 && tmp > 0 { + return nil, ErrDecrementOverflow + } + return tmp, nil + case int32: + if val == math.MinInt32 { + return nil, ErrDecrementOverflow + } + return val - 1, nil + case int64: + if val == math.MinInt64 { + return nil, ErrDecrementOverflow + } + return val - 1, nil + case uint: + if val == 0 { + return nil, ErrDecrementOverflow + } + return val - 1, nil + case uint32: + if val == MinUint32 { + return nil, ErrDecrementOverflow + } + return val - 1, nil + case uint64: + if val == MinUint64 { + return nil, ErrDecrementOverflow + } + return val - 1, nil + default: + return nil, ErrNotIntegerType + } +} diff --git a/client/cache/calc_utils_test.go b/client/cache/calc_utils_test.go new file mode 100644 index 0000000000..1f8d33778f --- /dev/null +++ b/client/cache/calc_utils_test.go @@ -0,0 +1,140 @@ +package cache + +import ( + "math" + "strconv" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestIncr(t *testing.T) { + // int + var originVal interface{} = int(1) + var updateVal interface{} = int(2) + val, err := incr(originVal) + assert.Nil(t, err) + assert.Equal(t, val, updateVal) + + _, err = incr(int(1<<(strconv.IntSize-1) - 1)) + assert.Equal(t, ErrIncrementOverflow, err) + + // int32 + originVal = int32(1) + updateVal = int32(2) + val, err = incr(originVal) + assert.Nil(t, err) + assert.Equal(t, val, updateVal) + + _, err = incr(int32(math.MaxInt32)) + assert.Equal(t, ErrIncrementOverflow, err) + + // int64 + originVal = int64(1) + updateVal = int64(2) + val, err = incr(originVal) + assert.Nil(t, err) + assert.Equal(t, val, updateVal) + + _, err = incr(int64(math.MaxInt64)) + assert.Equal(t, ErrIncrementOverflow, err) + + // uint + originVal = uint(1) + updateVal = uint(2) + val, err = incr(originVal) + + assert.Nil(t, err) + assert.Equal(t, val, updateVal) + + _, err = incr(uint(1<<(strconv.IntSize) - 1)) + assert.Equal(t, ErrIncrementOverflow, err) + + // uint32 + originVal = uint32(1) + updateVal = uint32(2) + val, err = incr(originVal) + assert.Nil(t, err) + assert.Equal(t, val, updateVal) + + _, err = incr(uint32(math.MaxUint32)) + assert.Equal(t, ErrIncrementOverflow, err) + + // uint64 + originVal = uint64(1) + updateVal = uint64(2) + val, err = incr(originVal) + assert.Nil(t, err) + assert.Equal(t, val, updateVal) + + _, err = incr(uint64(math.MaxUint64)) + assert.Equal(t, ErrIncrementOverflow, err) + // other type + _, err = incr("string") + assert.Equal(t, ErrNotIntegerType, err) +} + +func TestDecr(t *testing.T) { + // int + var originVal interface{} = int(2) + var updateVal interface{} = int(1) + val, err := decr(originVal) + assert.Nil(t, err) + assert.Equal(t, val, updateVal) + + _, err = decr(int(-1 << (strconv.IntSize - 1))) + assert.Equal(t, ErrDecrementOverflow, err) + // int32 + originVal = int32(2) + updateVal = int32(1) + val, err = decr(originVal) + assert.Nil(t, err) + assert.Equal(t, val, updateVal) + + _, err = decr(int32(math.MinInt32)) + assert.Equal(t, ErrDecrementOverflow, err) + + // int64 + originVal = int64(2) + updateVal = int64(1) + val, err = decr(originVal) + assert.Nil(t, err) + assert.Equal(t, val, updateVal) + + _, err = decr(int64(math.MinInt64)) + assert.Equal(t, ErrDecrementOverflow, err) + + // uint + originVal = uint(2) + updateVal = uint(1) + val, err = decr(originVal) + assert.Nil(t, err) + assert.Equal(t, val, updateVal) + + _, err = decr(uint(0)) + assert.Equal(t, ErrDecrementOverflow, err) + + // uint32 + originVal = uint32(2) + updateVal = uint32(1) + val, err = decr(originVal) + assert.Nil(t, err) + assert.Equal(t, val, updateVal) + + _, err = decr(uint32(0)) + assert.Equal(t, ErrDecrementOverflow, err) + + // uint64 + originVal = uint64(2) + updateVal = uint64(1) + val, err = decr(originVal) + assert.Nil(t, err) + assert.Equal(t, val, updateVal) + + _, err = decr(uint64(0)) + assert.Equal(t, ErrDecrementOverflow, err) + + // other type + _, err = decr("string") + assert.Equal(t, ErrNotIntegerType, err) +} diff --git a/cache/conv.go b/client/cache/conv.go similarity index 90% rename from cache/conv.go rename to client/cache/conv.go index 8780058640..158f7f413f 100644 --- a/cache/conv.go +++ b/client/cache/conv.go @@ -19,7 +19,7 @@ import ( "strconv" ) -// GetString convert interface to string. +// GetString converts interface to string. func GetString(v interface{}) string { switch result := v.(type) { case string: @@ -34,7 +34,7 @@ func GetString(v interface{}) string { return "" } -// GetInt convert interface to int. +// GetInt converts interface to int. func GetInt(v interface{}) int { switch result := v.(type) { case int: @@ -52,7 +52,7 @@ func GetInt(v interface{}) int { return 0 } -// GetInt64 convert interface to int64. +// GetInt64 converts interface to int64. func GetInt64(v interface{}) int64 { switch result := v.(type) { case int: @@ -71,7 +71,7 @@ func GetInt64(v interface{}) int64 { return 0 } -// GetFloat64 convert interface to float64. +// GetFloat64 converts interface to float64. func GetFloat64(v interface{}) float64 { switch result := v.(type) { case float64: @@ -85,7 +85,7 @@ func GetFloat64(v interface{}) float64 { return 0 } -// GetBool convert interface to bool. +// GetBool converts interface to bool. func GetBool(v interface{}) bool { switch result := v.(type) { case bool: diff --git a/client/cache/conv_test.go b/client/cache/conv_test.go new file mode 100644 index 0000000000..c261c440d9 --- /dev/null +++ b/client/cache/conv_test.go @@ -0,0 +1,89 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cache + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestGetString(t *testing.T) { + t1 := "test1" + + assert.Equal(t, "test1", GetString(t1)) + t2 := []byte("test2") + assert.Equal(t, "test2", GetString(t2)) + t3 := 1 + assert.Equal(t, "1", GetString(t3)) + var t4 int64 = 1 + assert.Equal(t, "1", GetString(t4)) + t5 := 1.1 + assert.Equal(t, "1.1", GetString(t5)) + assert.Equal(t, "", GetString(nil)) +} + +func TestGetInt(t *testing.T) { + t1 := 1 + assert.Equal(t, 1, GetInt(t1)) + var t2 int32 = 32 + assert.Equal(t, 32, GetInt(t2)) + + var t3 int64 = 64 + assert.Equal(t, 64, GetInt(t3)) + t4 := "128" + + assert.Equal(t, 128, GetInt(t4)) + assert.Equal(t, 0, GetInt(nil)) +} + +func TestGetInt64(t *testing.T) { + var i int64 = 1 + t1 := 1 + assert.Equal(t, i, GetInt64(t1)) + var t2 int32 = 1 + + assert.Equal(t, i, GetInt64(t2)) + var t3 int64 = 1 + assert.Equal(t, i, GetInt64(t3)) + t4 := "1" + assert.Equal(t, i, GetInt64(t4)) + assert.Equal(t, int64(0), GetInt64(nil)) +} + +func TestGetFloat64(t *testing.T) { + f := 1.11 + var t1 float32 = 1.11 + assert.Equal(t, f, GetFloat64(t1)) + t2 := 1.11 + assert.Equal(t, f, GetFloat64(t2)) + t3 := "1.11" + assert.Equal(t, f, GetFloat64(t3)) + + var f2 float64 = 1 + t4 := 1 + assert.Equal(t, f2, GetFloat64(t4)) + + assert.Equal(t, float64(0), GetFloat64(nil)) +} + +func TestGetBool(t *testing.T) { + t1 := true + assert.True(t, GetBool(t1)) + t2 := "true" + assert.True(t, GetBool(t2)) + + assert.False(t, GetBool(nil)) +} diff --git a/client/cache/error_code.go b/client/cache/error_code.go new file mode 100644 index 0000000000..2aa9ffc83e --- /dev/null +++ b/client/cache/error_code.go @@ -0,0 +1,199 @@ +// Copyright 2021 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cache + +import ( + "github.com/beego/beego/v2/core/berror" +) + +var NilCacheAdapter = berror.DefineCode(4002001, moduleName, "NilCacheAdapter", ` +It means that you register cache adapter by pass nil. +A cache adapter is an instance of Cache interface. +`) + +var DuplicateAdapter = berror.DefineCode(4002002, moduleName, "DuplicateAdapter", ` +You register two adapter with same name. In beego cache module, one name one adapter. +Once you got this error, please check the error stack, search adapter +`) + +var UnknownAdapter = berror.DefineCode(4002003, moduleName, "UnknownAdapter", ` +Unknown adapter, do you forget to register the adapter? +You must register adapter before use it. For example, if you want to use redis implementation, +you must import the cache/redis package. +`) + +var IncrementOverflow = berror.DefineCode(4002004, moduleName, "IncrementOverflow", ` +The increment operation will overflow. +`) + +var DecrementOverflow = berror.DefineCode(4002005, moduleName, "DecrementOverflow", ` +The decrement operation will overflow. +`) + +var NotIntegerType = berror.DefineCode(4002006, moduleName, "NotIntegerType", ` +The type of value is not (u)int (u)int32 (u)int64. +When you want to call Incr or Decr function of Cache API, you must confirm that the value's type is one of (u)int (u)int32 (u)int64. +`) + +var InvalidFileCacheDirectoryLevelCfg = berror.DefineCode(4002007, moduleName, "InvalidFileCacheDirectoryLevelCfg", ` +You pass invalid DirectoryLevel parameter when you try to StartAndGC file cache instance. +This parameter must be a integer, and please check your input. +`) + +var InvalidFileCacheEmbedExpiryCfg = berror.DefineCode(4002008, moduleName, "InvalidFileCacheEmbedExpiryCfg", ` +You pass invalid EmbedExpiry parameter when you try to StartAndGC file cache instance. +This parameter must be a integer, and please check your input. +`) + +var CreateFileCacheDirFailed = berror.DefineCode(4002009, moduleName, "CreateFileCacheDirFailed", ` +Beego failed to create file cache directory. There are two cases: +1. You pass invalid CachePath parameter. Please check your input. +2. Beego doesn't have the permission to create this directory. Please check your file mode. +`) + +var InvalidFileCachePath = berror.DefineCode(4002010, moduleName, "InvalidFilePath", ` +The file path of FileCache is invalid. Please correct the config. +`) + +var ReadFileCacheContentFailed = berror.DefineCode(4002011, moduleName, "ReadFileCacheContentFailed", ` +Usually you won't got this error. It means that Beego cannot read the data from the file. +You need to check whether the file exist. Sometimes it may be deleted by other processes. +If the file exists, please check the permission that Beego is able to read data from the file. +`) + +var InvalidGobEncodedData = berror.DefineCode(4002012, moduleName, "InvalidEncodedData", ` +The data is invalid. When you try to decode the invalid data, you got this error. +Please confirm that the data is encoded by GOB correctly. +`) + +var GobEncodeDataFailed = berror.DefineCode(4002013, moduleName, "GobEncodeDataFailed", ` +Beego could not encode the data to GOB byte array. In general, the data type is invalid. +For example, GOB doesn't support function type. +Basic types, string, structure, structure pointer are supported. +`) + +var KeyExpired = berror.DefineCode(4002014, moduleName, "KeyExpired", ` +Cache key is expired. +You should notice that, a key is expired and then it may be deleted by GC goroutine. +So when you query a key which may be expired, you may got this code, or KeyNotExist. +`) + +var KeyNotExist = berror.DefineCode(4002015, moduleName, "KeyNotExist", ` +Key not found. +`) + +var MultiGetFailed = berror.DefineCode(4002016, moduleName, "MultiGetFailed", ` +Get multiple keys failed. Please check the detail msg to find out the root cause. +`) + +var InvalidMemoryCacheCfg = berror.DefineCode(4002017, moduleName, "InvalidMemoryCacheCfg", ` +The config is invalid. Please check your input. It must be a json string. +`) + +var InvalidMemCacheCfg = berror.DefineCode(4002018, moduleName, "InvalidMemCacheCfg", ` +The config is invalid. Please check your input, it must be json string and contains "conn" field. +`) + +var InvalidMemCacheValue = berror.DefineCode(4002019, moduleName, "InvalidMemCacheValue", ` +The value must be string or byte[], please check your input. +`) + +var InvalidRedisCacheCfg = berror.DefineCode(4002020, moduleName, "InvalidRedisCacheCfg", ` +The config must be json string, and has "conn" field. +`) + +var InvalidSsdbCacheCfg = berror.DefineCode(4002021, moduleName, "InvalidSsdbCacheCfg", ` +The config must be json string, and has "conn" field. The value of "conn" field should be "host:port". +"port" must be a valid integer. +`) + +var InvalidSsdbCacheValue = berror.DefineCode(4002022, moduleName, "InvalidSsdbCacheValue", ` +SSDB cache only accept string value. Please check your input. +`) + +var InvalidLoadFunc = berror.DefineCode(4002023, moduleName, "InvalidLoadFunc", ` +Invalid load function for read-through pattern decorator. +You should pass a valid(non-nil) load function when initiate the decorator instance. +`) + +var LoadFuncFailed = berror.DefineCode(4002024, moduleName, "LoadFuncFailed", ` +Failed to load data, please check whether the loadfunc is correct +`) + +var InvalidInitParameters = berror.DefineCode(4002025, moduleName, "InvalidInitParameters", ` +Invalid init cache parameters. +You can check the related function to confirm that if you pass correct parameters or configure to initiate a Cache instance. +`) + +var PersistCacheFailed = berror.DefineCode(4002026, moduleName, "PersistCacheFailed", ` +Failed to execute the StoreFunc. +Please check the log to make sure the StoreFunc works for the specific key and value. +`) + +var DeleteFileCacheItemFailed = berror.DefineCode(5002001, moduleName, "DeleteFileCacheItemFailed", ` +Beego try to delete file cache item failed. +Please check whether Beego generated file correctly. +And then confirm whether this file is already deleted by other processes or other people. +`) + +var MemCacheCurdFailed = berror.DefineCode(5002002, moduleName, "MemCacheError", ` +When you want to get, put, delete key-value from remote memcache servers, you may get error: +1. You pass invalid servers address, so Beego could not connect to remote server; +2. The servers address is correct, but there is some net issue. Typically there is some firewalls between application and memcache server; +3. Key is invalid. The key's length should be less than 250 and must not contains special characters; +4. The response from memcache server is invalid; +`) + +var RedisCacheCurdFailed = berror.DefineCode(5002003, moduleName, "RedisCacheCurdFailed", ` +When Beego uses client to send request to redis server, it failed. +1. The server addresses is invalid; +2. Network issue, firewall issue or network is unstable; +3. Client failed to manage connection. In extreme cases, Beego's redis client didn't maintain connections correctly, for example, Beego try to send request via closed connection; +4. The request are huge and redis server spent too much time to process it, and client is timeout; + +In general, if you always got this error whatever you do, in most cases, it was caused by network issue. +You could check your network state, and confirm that firewall rules are correct. +`) + +var InvalidConnection = berror.DefineCode(5002004, moduleName, "InvalidConnection", ` +The connection is invalid. Please check your connection info, network, firewall. +You could simply uses ping, telnet or write some simple tests to test network. +`) + +var DialFailed = berror.DefineCode(5002005, moduleName, "DialFailed", ` +When Beego try to dial to remote servers, it failed. Please check your connection info and network state, server state. +`) + +var SsdbCacheCurdFailed = berror.DefineCode(5002006, moduleName, "SsdbCacheCurdFailed", ` +When you try to use SSDB cache, it failed. There are many cases: +1. servers unavailable; +2. network issue, including network unstable, firewall; +3. connection issue; +4. request are huge and servers spent too much time to process it, got timeout; +`) + +var SsdbBadResponse = berror.DefineCode(5002007, moduleName, "SsdbBadResponse", ` +The response from SSDB server is invalid. +Usually it indicates something wrong on server side. +`) + +var DeleteFailed = berror.DefineCode(5002008, moduleName, "DeleteFailed", ` +Beego attempt to delete cache item failed. Please check if the target key is correct. +`) + +var ( + ErrKeyExpired = berror.Error(KeyExpired, "the key is expired") + ErrKeyNotExist = berror.Error(KeyNotExist, "the key isn't exist") +) diff --git a/client/cache/file.go b/client/cache/file.go new file mode 100644 index 0000000000..e665cdc16a --- /dev/null +++ b/client/cache/file.go @@ -0,0 +1,339 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cache + +import ( + "bytes" + "context" + "crypto/sha256" + "encoding/gob" + "encoding/hex" + "encoding/json" + "fmt" + "io" + "os" + "path/filepath" + "strconv" + "strings" + "time" + + "github.com/beego/beego/v2/core/berror" +) + +// FileCacheItem is basic unit of file cache adapter which +// contains data and expire time. +type FileCacheItem struct { + Data interface{} + Lastaccess time.Time + Expired time.Time +} + +// FileCache Config +var ( + FileCachePath = "cache" // cache directory + FileCacheFileSuffix = ".bin" // cache file suffix + FileCacheDirectoryLevel = 2 // cache file deep level if auto generated cache files. + FileCacheEmbedExpiry time.Duration // cache expire time, default is no expire forever. +) + +// FileCache is cache adapter for file storage. +type FileCache struct { + CachePath string + FileSuffix string + DirectoryLevel int + EmbedExpiry int +} + +// NewFileCache creates a new file cache with no config. +// The level and expiry need to be set in the method StartAndGC as config string. +func NewFileCache() Cache { + // return &FileCache{CachePath:FileCachePath, FileSuffix:FileCacheFileSuffix} + return &FileCache{} +} + +// StartAndGC starts gc for file cache. +// config must be in the format {CachePath:"/cache","FileSuffix":".bin","DirectoryLevel":"2","EmbedExpiry":"0"} +func (fc *FileCache) StartAndGC(config string) error { + cfg := make(map[string]string) + err := json.Unmarshal([]byte(config), &cfg) + if err != nil { + return err + } + + const cpKey = "CachePath" + const fsKey = "FileSuffix" + const dlKey = "DirectoryLevel" + const eeKey = "EmbedExpiry" + + if _, ok := cfg[cpKey]; !ok { + cfg[cpKey] = FileCachePath + } + + if _, ok := cfg[fsKey]; !ok { + cfg[fsKey] = FileCacheFileSuffix + } + + if _, ok := cfg[dlKey]; !ok { + cfg[dlKey] = strconv.Itoa(FileCacheDirectoryLevel) + } + + if _, ok := cfg[eeKey]; !ok { + cfg[eeKey] = strconv.FormatInt(int64(FileCacheEmbedExpiry.Seconds()), 10) + } + fc.CachePath = cfg[cpKey] + fc.FileSuffix = cfg[fsKey] + fc.DirectoryLevel, err = strconv.Atoi(cfg[dlKey]) + if err != nil { + return berror.Wrapf(err, InvalidFileCacheDirectoryLevelCfg, + "invalid directory level config, please check your input, it must be integer: %s", cfg[dlKey]) + } + fc.EmbedExpiry, err = strconv.Atoi(cfg[eeKey]) + if err != nil { + return berror.Wrapf(err, InvalidFileCacheEmbedExpiryCfg, + "invalid embed expiry config, please check your input, it must be integer: %s", cfg[eeKey]) + } + return fc.Init() +} + +// Init makes new a dir for file cache if it does not already exist +func (fc *FileCache) Init() error { + ok, err := exists(fc.CachePath) + if err != nil || ok { + return err + } + err = os.MkdirAll(fc.CachePath, os.ModePerm) + if err != nil { + return berror.Wrapf(err, CreateFileCacheDirFailed, + "could not create directory, please check the config [%s] and file mode.", fc.CachePath) + } + return nil +} + +// getCacheFileName returns a md5 encoded file name. +func (fc *FileCache) getCacheFileName(key string) (string, error) { + m := sha256.New() + _, _ = io.WriteString(m, key) + keySha256 := hex.EncodeToString(m.Sum(nil)) + cachePath := fc.CachePath + switch fc.DirectoryLevel { + case 2: + cachePath = filepath.Join(cachePath, keySha256[0:2], keySha256[2:4]) + case 1: + cachePath = filepath.Join(cachePath, keySha256[0:2]) + } + ok, err := exists(cachePath) + if err != nil { + return "", err + } + if !ok { + fmt.Printf("cachePath: %s\n", cachePath) + err = os.MkdirAll(cachePath, os.ModePerm) + if err != nil { + return "", berror.Wrapf(err, CreateFileCacheDirFailed, + "could not create the directory: %s", cachePath) + } + } + + return filepath.Join(cachePath, fmt.Sprintf("%s%s", keySha256, fc.FileSuffix)), nil +} + +// Get value from file cache. +// if nonexistent or expired return an empty string. +func (fc *FileCache) Get(ctx context.Context, key string) (interface{}, error) { + fn, err := fc.getCacheFileName(key) + if err != nil { + return nil, err + } + fileData, err := FileGetContents(fn) + if err != nil { + return nil, err + } + + var to FileCacheItem + err = GobDecode(fileData, &to) + if err != nil { + return nil, err + } + + if to.Expired.Before(time.Now()) { + return nil, ErrKeyExpired + } + return to.Data, nil +} + +// GetMulti gets values from file cache. +// if nonexistent or expired return an empty string. +func (fc *FileCache) GetMulti(ctx context.Context, keys []string) ([]interface{}, error) { + rc := make([]interface{}, len(keys)) + keysErr := make([]string, 0) + + for i, ki := range keys { + val, err := fc.Get(context.Background(), ki) + if err != nil { + keysErr = append(keysErr, fmt.Sprintf("key [%s] error: %s", ki, err.Error())) + continue + } + rc[i] = val + } + + if len(keysErr) == 0 { + return rc, nil + } + return rc, berror.Error(MultiGetFailed, strings.Join(keysErr, "; ")) +} + +// Put value into file cache. +// timeout: how long this file should be kept in ms +// if timeout equals fc.EmbedExpiry(default is 0), cache this item forever. +func (fc *FileCache) Put(ctx context.Context, key string, val interface{}, timeout time.Duration) error { + gob.Register(val) + + item := FileCacheItem{Data: val} + if timeout == time.Duration(fc.EmbedExpiry) { + item.Expired = time.Now().Add((86400 * 365 * 10) * time.Second) // ten years + } else { + item.Expired = time.Now().Add(timeout) + } + item.Lastaccess = time.Now() + data, err := GobEncode(item) + if err != nil { + return err + } + + fn, err := fc.getCacheFileName(key) + + if err != nil { + return err + } + return FilePutContents(fn, data) +} + +// Delete file cache value. +func (fc *FileCache) Delete(ctx context.Context, key string) error { + filename, err := fc.getCacheFileName(key) + if err != nil { + return err + } + if ok, _ := exists(filename); ok { + err = os.Remove(filename) + if err != nil { + return berror.Wrapf(err, DeleteFileCacheItemFailed, + "can not delete this file cache key-value, key is %s and file name is %s", key, filename) + } + } + return nil +} + +// Incr increases cached int value. +// fc value is saved forever unless deleted. +func (fc *FileCache) Incr(ctx context.Context, key string) error { + data, err := fc.Get(context.Background(), key) + if err != nil { + return err + } + + val, err := incr(data) + if err != nil { + return err + } + + return fc.Put(context.Background(), key, val, time.Duration(fc.EmbedExpiry)) +} + +// Decr decreases cached int value. +func (fc *FileCache) Decr(ctx context.Context, key string) error { + data, err := fc.Get(context.Background(), key) + if err != nil { + return err + } + + val, err := decr(data) + if err != nil { + return err + } + + return fc.Put(context.Background(), key, val, time.Duration(fc.EmbedExpiry)) +} + +// IsExist checks if value exists. +func (fc *FileCache) IsExist(ctx context.Context, key string) (bool, error) { + fn, err := fc.getCacheFileName(key) + if err != nil { + return false, err + } + return exists(fn) +} + +// ClearAll cleans cached files (not implemented) +func (fc *FileCache) ClearAll(context.Context) error { + return nil +} + +// Check if a file exists +func exists(path string) (bool, error) { + _, err := os.Stat(path) + if err == nil { + return true, nil + } + if os.IsNotExist(err) { + return false, nil + } + return false, berror.Wrapf(err, InvalidFileCachePath, "file cache path is invalid: %s", path) +} + +// FileGetContents Reads bytes from a file. +// if non-existent, create this file. +func FileGetContents(filename string) ([]byte, error) { + data, err := os.ReadFile(filename) + if err != nil { + return nil, berror.Wrapf(err, ReadFileCacheContentFailed, + "could not read the data from the file: %s, "+ + "please confirm that file exist and Beego has the permission to read the content.", filename) + } + return data, nil +} + +// FilePutContents puts bytes into a file. +// if non-existent, create this file. +func FilePutContents(filename string, content []byte) error { + return os.WriteFile(filename, content, os.ModePerm) +} + +// GobEncode Gob encodes a file cache item. +func GobEncode(data interface{}) ([]byte, error) { + buf := bytes.NewBuffer(nil) + enc := gob.NewEncoder(buf) + err := enc.Encode(data) + if err != nil { + return nil, berror.Wrap(err, GobEncodeDataFailed, "could not encode this data") + } + return buf.Bytes(), nil +} + +// GobDecode Gob decodes a file cache item. +func GobDecode(data []byte, to *FileCacheItem) error { + buf := bytes.NewBuffer(data) + dec := gob.NewDecoder(buf) + err := dec.Decode(&to) + if err != nil { + return berror.Wrap(err, InvalidGobEncodedData, + "could not decode this data to FileCacheItem. Make sure that the data is encoded by GOB.") + } + return nil +} + +func init() { + Register("file", NewFileCache) +} diff --git a/client/cache/file_test.go b/client/cache/file_test.go new file mode 100644 index 0000000000..db2b8ada87 --- /dev/null +++ b/client/cache/file_test.go @@ -0,0 +1,114 @@ +// Copyright 2021 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cache + +import ( + "context" + "fmt" + "github.com/stretchr/testify/assert" + "os" + "path/filepath" + "testing" +) + +func TestFileCacheStartAndGC(t *testing.T) { + fc := NewFileCache().(*FileCache) + err := fc.StartAndGC(`{`) + assert.NotNil(t, err) + err = fc.StartAndGC(`{}`) + assert.Nil(t, err) + _, err = fc.getCacheFileName("key1") + assert.Nil(t, err) + + assert.Equal(t, fc.CachePath, FileCachePath) + assert.Equal(t, fc.DirectoryLevel, FileCacheDirectoryLevel) + assert.Equal(t, fc.EmbedExpiry, int(FileCacheEmbedExpiry)) + assert.Equal(t, fc.FileSuffix, FileCacheFileSuffix) + + err = fc.StartAndGC(`{"CachePath":"/cache","FileSuffix":".bin","DirectoryLevel":"2","EmbedExpiry":"0"}`) + // could not create dir + assert.NotNil(t, err) + + str := getTestCacheFilePath() + err = fc.StartAndGC(fmt.Sprintf(`{"CachePath":"%s","FileSuffix":".bin","DirectoryLevel":"2","EmbedExpiry":"0"}`, str)) + assert.Nil(t, err) + assert.Equal(t, fc.CachePath, str) + assert.Equal(t, fc.DirectoryLevel, 2) + assert.Equal(t, fc.EmbedExpiry, 0) + assert.Equal(t, fc.FileSuffix, ".bin") + _, err = fc.getCacheFileName("key1") + assert.Nil(t, err) + + err = fc.StartAndGC(fmt.Sprintf(`{"CachePath":"%s","FileSuffix":".bin","DirectoryLevel":"aaa","EmbedExpiry":"0"}`, str)) + assert.NotNil(t, err) + + err = fc.StartAndGC(fmt.Sprintf(`{"CachePath":"%s","FileSuffix":".bin","DirectoryLevel":"2","EmbedExpiry":"aaa"}`, str)) + assert.NotNil(t, err) + + _, err = fc.getCacheFileName("key1") + assert.Nil(t, err) +} + +func TestFileCacheInit(t *testing.T) { + fc := NewFileCache().(*FileCache) + fc.CachePath = "////aaa" + err := fc.Init() + assert.NotNil(t, err) + fc.CachePath = getTestCacheFilePath() + err = fc.Init() + assert.Nil(t, err) +} + +func TestFileGetContents(t *testing.T) { + _, err := FileGetContents("/bin/aaa") + assert.NotNil(t, err) + fn := filepath.Join(os.TempDir(), "fileCache.txt") + f, err := os.Create(fn) + assert.Nil(t, err) + _, err = f.WriteString("text") + assert.Nil(t, err) + data, err := FileGetContents(fn) + assert.Nil(t, err) + assert.Equal(t, "text", string(data)) +} + +func TestGobEncodeDecode(t *testing.T) { + _, err := GobEncode(func() { + fmt.Print("test func") + }) + assert.NotNil(t, err) + data, err := GobEncode(&FileCacheItem{ + Data: "hello", + }) + assert.Nil(t, err) + err = GobDecode([]byte("wrong data"), &FileCacheItem{}) + assert.NotNil(t, err) + dci := &FileCacheItem{} + err = GobDecode(data, dci) + assert.Nil(t, err) + assert.Equal(t, "hello", dci.Data) +} + +func TestFileCacheDelete(t *testing.T) { + fc := NewFileCache() + err := fc.StartAndGC(`{}`) + assert.Nil(t, err) + err = fc.Delete(context.Background(), "my-key") + assert.Nil(t, err) +} + +func getTestCacheFilePath() string { + return filepath.Join(os.TempDir(), "test", "file.txt") +} diff --git a/client/cache/memcache/memcache.go b/client/cache/memcache/memcache.go new file mode 100644 index 0000000000..4c6cf0dd26 --- /dev/null +++ b/client/cache/memcache/memcache.go @@ -0,0 +1,161 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package memcache for cache provider +// +// depend on github.com/bradfitz/gomemcache/memcache +// +// go install github.com/bradfitz/gomemcache/memcache +// +// Usage: +// import( +// +// _ "github.com/beego/beego/v2/client/cache/memcache" +// "github.com/beego/beego/v2/client/cache" +// +// ) +// +// bm, err := cache.NewCache("memcache", `{"conn":"127.0.0.1:11211"}`) +package memcache + +import ( + "context" + "encoding/json" + "fmt" + "strings" + "time" + + "github.com/bradfitz/gomemcache/memcache" + + "github.com/beego/beego/v2/client/cache" + "github.com/beego/beego/v2/core/berror" +) + +// Cache Memcache adapter. +type Cache struct { + conn *memcache.Client + conninfo []string +} + +// NewMemCache creates a new memcache adapter. +func NewMemCache() cache.Cache { + return &Cache{} +} + +// Get get value from memcache. +func (rc *Cache) Get(ctx context.Context, key string) (interface{}, error) { + if item, err := rc.conn.Get(key); err == nil { + return item.Value, nil + } else { + return nil, berror.Wrapf(err, cache.MemCacheCurdFailed, + "could not read data from memcache, please check your key, network and connection. Root cause: %s", + err.Error()) + } +} + +// GetMulti gets a value from a key in memcache. +func (rc *Cache) GetMulti(ctx context.Context, keys []string) ([]interface{}, error) { + rv := make([]interface{}, len(keys)) + + mv, err := rc.conn.GetMulti(keys) + if err != nil { + return rv, berror.Wrapf(err, cache.MemCacheCurdFailed, + "could not read multiple key-values from memcache, "+ + "please check your keys, network and connection. Root cause: %s", + err.Error()) + } + + keysErr := make([]string, 0) + for i, ki := range keys { + if _, ok := mv[ki]; !ok { + keysErr = append(keysErr, fmt.Sprintf("key [%s] error: %s", ki, "key not exist")) + continue + } + rv[i] = mv[ki].Value + } + + if len(keysErr) == 0 { + return rv, nil + } + return rv, berror.Error(cache.MultiGetFailed, strings.Join(keysErr, "; ")) +} + +// Put puts a value into memcache. +func (rc *Cache) Put(ctx context.Context, key string, val interface{}, timeout time.Duration) error { + item := memcache.Item{Key: key, Expiration: int32(timeout / time.Second)} + if v, ok := val.([]byte); ok { + item.Value = v + } else if str, ok := val.(string); ok { + item.Value = []byte(str) + } else { + return berror.Errorf(cache.InvalidMemCacheValue, + "the value must be string or byte[]. key: %s, value:%v", key, val) + } + return berror.Wrapf(rc.conn.Set(&item), cache.MemCacheCurdFailed, + "could not put key-value to memcache, key: %s", key) +} + +// Delete deletes a value in memcache. +func (rc *Cache) Delete(ctx context.Context, key string) error { + return berror.Wrapf(rc.conn.Delete(key), cache.MemCacheCurdFailed, + "could not delete key-value from memcache, key: %s", key) +} + +// Incr increases counter. +func (rc *Cache) Incr(ctx context.Context, key string) error { + _, err := rc.conn.Increment(key, 1) + return berror.Wrapf(err, cache.MemCacheCurdFailed, + "could not increase value for key: %s", key) +} + +// Decr decreases counter. +func (rc *Cache) Decr(ctx context.Context, key string) error { + _, err := rc.conn.Decrement(key, 1) + return berror.Wrapf(err, cache.MemCacheCurdFailed, + "could not decrease value for key: %s", key) +} + +// IsExist checks if a value exists in memcache. +func (rc *Cache) IsExist(ctx context.Context, key string) (bool, error) { + _, err := rc.Get(ctx, key) + return err == nil, err +} + +// ClearAll clears all cache in memcache. +func (rc *Cache) ClearAll(context.Context) error { + return berror.Wrap(rc.conn.FlushAll(), cache.MemCacheCurdFailed, + "try to clear all key-value pairs failed") +} + +// StartAndGC starts the memcache adapter. +// config: must be in the format {"conn":"connection info"}. +// If an error occurs during connecting, an error is returned +func (rc *Cache) StartAndGC(config string) error { + var cf map[string]string + if err := json.Unmarshal([]byte(config), &cf); err != nil { + return berror.Wrapf(err, cache.InvalidMemCacheCfg, + "could not unmarshal this config, it must be valid json stringP: %s", config) + } + + if _, ok := cf["conn"]; !ok { + return berror.Errorf(cache.InvalidMemCacheCfg, `config must contains "conn" field: %s`, config) + } + rc.conninfo = strings.Split(cf["conn"], ";") + rc.conn = memcache.New(rc.conninfo...) + return nil +} + +func init() { + cache.Register("memcache", NewMemCache) +} diff --git a/client/cache/memcache/memcache_test.go b/client/cache/memcache/memcache_test.go new file mode 100644 index 0000000000..090262f99b --- /dev/null +++ b/client/cache/memcache/memcache_test.go @@ -0,0 +1,228 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package memcache + +import ( + "context" + "errors" + "fmt" + "os" + "strconv" + "strings" + "testing" + "time" + + _ "github.com/bradfitz/gomemcache/memcache" + "github.com/stretchr/testify/assert" + + "github.com/beego/beego/v2/client/cache" + "github.com/beego/beego/v2/core/berror" +) + +func TestMemcacheCache(t *testing.T) { + addr := os.Getenv("MEMCACHE_ADDR") + if addr == "" { + addr = "127.0.0.1:11211" + } + + bm, err := cache.NewCache("memcache", fmt.Sprintf(`{"conn": "%s"}`, addr)) + assert.Nil(t, err) + + timeoutDuration := 10 * time.Second + + assert.Nil(t, bm.Put(context.Background(), "astaxie", "1", timeoutDuration)) + res, _ := bm.IsExist(context.Background(), "astaxie") + assert.True(t, res) + + time.Sleep(11 * time.Second) + + res, _ = bm.IsExist(context.Background(), "astaxie") + assert.False(t, res) + + assert.Nil(t, bm.Put(context.Background(), "astaxie", "1", timeoutDuration)) + + val, _ := bm.Get(context.Background(), "astaxie") + v, err := strconv.Atoi(string(val.([]byte))) + assert.Nil(t, err) + assert.Equal(t, 1, v) + + assert.Nil(t, bm.Incr(context.Background(), "astaxie")) + + val, _ = bm.Get(context.Background(), "astaxie") + v, err = strconv.Atoi(string(val.([]byte))) + assert.Nil(t, err) + assert.Equal(t, 2, v) + + assert.Nil(t, bm.Decr(context.Background(), "astaxie")) + + val, _ = bm.Get(context.Background(), "astaxie") + v, err = strconv.Atoi(string(val.([]byte))) + assert.Nil(t, err) + assert.Equal(t, 1, v) + assert.Nil(t, bm.Delete(context.Background(), "astaxie")) + + res, _ = bm.IsExist(context.Background(), "astaxie") + assert.False(t, res) + + assert.Nil(t, bm.Put(context.Background(), "astaxie", "author", timeoutDuration)) + // test string + res, _ = bm.IsExist(context.Background(), "astaxie") + assert.True(t, res) + + val, _ = bm.Get(context.Background(), "astaxie") + vs := val.([]byte) + assert.Equal(t, "author", string(vs)) + + // test GetMulti + assert.Nil(t, bm.Put(context.Background(), "astaxie1", "author1", timeoutDuration)) + + res, _ = bm.IsExist(context.Background(), "astaxie1") + assert.True(t, res) + + vv, _ := bm.GetMulti(context.Background(), []string{"astaxie", "astaxie1"}) + assert.Equal(t, 2, len(vv)) + + if string(vv[0].([]byte)) != "author" && string(vv[0].([]byte)) != "author1" { + t.Error("GetMulti ERROR") + } + if string(vv[1].([]byte)) != "author1" && string(vv[1].([]byte)) != "author" { + t.Error("GetMulti ERROR") + } + + vv, err = bm.GetMulti(context.Background(), []string{"astaxie0", "astaxie1"}) + assert.Equal(t, 2, len(vv)) + assert.Nil(t, vv[0]) + + assert.Equal(t, "author1", string(vv[1].([]byte))) + + assert.NotNil(t, err) + assert.True(t, strings.Contains(err.Error(), "key not exist")) + + assert.Nil(t, bm.ClearAll(context.Background())) + // test clear all +} + +func TestReadThroughCache_Memcache_Get(t *testing.T) { + bm, err := cache.NewCache("memcache", fmt.Sprintf(`{"conn": "%s"}`, "127.0.0.1:11211")) + assert.Nil(t, err) + + testReadThroughCacheGet(t, bm) +} + +func testReadThroughCacheGet(t *testing.T, bm cache.Cache) { + testCases := []struct { + name string + key string + value string + cache cache.Cache + wantErr error + }{ + { + name: "Get load err", + key: "key0", + cache: func() cache.Cache { + kvs := map[string]any{"key0": "value0"} + db := &MockOrm{kvs: kvs} + loadfunc := func(ctx context.Context, key string) (any, error) { + v, er := db.Load(key) + if er != nil { + return nil, er + } + val := []byte(v.(string)) + return val, nil + } + c, err := cache.NewReadThroughCache(bm, 3*time.Second, loadfunc) + assert.Nil(t, err) + return c + }(), + wantErr: func() error { + err := errors.New("the key not exist") + return berror.Wrap( + err, cache.LoadFuncFailed, "cache unable to load data") + }(), + }, + { + name: "Get cache exist", + key: "key1", + value: "value1", + cache: func() cache.Cache { + keysMap := map[string]int{"key1": 1} + kvs := map[string]any{"key1": "value1"} + db := &MockOrm{keysMap: keysMap, kvs: kvs} + loadfunc := func(ctx context.Context, key string) (any, error) { + v, er := db.Load(key) + if er != nil { + return nil, er + } + val := []byte(v.(string)) + return val, nil + } + c, err := cache.NewReadThroughCache(bm, 3*time.Second, loadfunc) + assert.Nil(t, err) + err = c.Put(context.Background(), "key1", "value1", 3*time.Second) + assert.Nil(t, err) + return c + }(), + }, + { + name: "Get loadFunc exist", + key: "key2", + value: "value2", + cache: func() cache.Cache { + keysMap := map[string]int{"key2": 1} + kvs := map[string]any{"key2": "value2"} + db := &MockOrm{keysMap: keysMap, kvs: kvs} + loadfunc := func(ctx context.Context, key string) (any, error) { + v, er := db.Load(key) + if er != nil { + return nil, er + } + val := []byte(v.(string)) + return val, nil + } + c, err := cache.NewReadThroughCache(bm, 3*time.Second, loadfunc) + assert.Nil(t, err) + return c + }(), + }, + } + _, err := cache.NewReadThroughCache(bm, 3*time.Second, nil) + assert.Equal(t, berror.Error(cache.InvalidLoadFunc, "loadFunc cannot be nil"), err) + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + bs := []byte(tc.value) + c := tc.cache + val, err := c.Get(context.Background(), tc.key) + if err != nil { + assert.EqualError(t, tc.wantErr, err.Error()) + return + } + assert.Equal(t, bs, val) + }) + } +} + +type MockOrm struct { + keysMap map[string]int + kvs map[string]any +} + +func (m *MockOrm) Load(key string) (any, error) { + _, ok := m.keysMap[key] + if !ok { + return nil, errors.New("the key not exist") + } + return m.kvs[key], nil +} diff --git a/client/cache/memory.go b/client/cache/memory.go new file mode 100644 index 0000000000..0caa3a8c71 --- /dev/null +++ b/client/cache/memory.go @@ -0,0 +1,234 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cache + +import ( + "context" + "encoding/json" + "fmt" + "strings" + "sync" + "time" + + "github.com/beego/beego/v2/core/berror" +) + +// DefaultEvery sets a timer for how often to recycle the expired cache items in memory (in seconds) +var DefaultEvery = 60 // 1 minute + +// MemoryItem stores memory cache item. +type MemoryItem struct { + val interface{} + createdTime time.Time + lifespan time.Duration +} + +func (mi *MemoryItem) isExpire() bool { + // 0 means forever + if mi.lifespan == 0 { + return false + } + return time.Since(mi.createdTime) > mi.lifespan +} + +// MemoryCache is a memory cache adapter. +// Contains a RW locker for safe map storage. +type MemoryCache struct { + sync.RWMutex + dur time.Duration + items map[string]*MemoryItem + Every int // run an expiration check Every clock time +} + +// NewMemoryCache returns a new MemoryCache. +func NewMemoryCache() Cache { + cache := MemoryCache{items: make(map[string]*MemoryItem)} + return &cache +} + +// Get returns cache from memory. +// If non-existent or expired, return nil. +func (bc *MemoryCache) Get(ctx context.Context, key string) (interface{}, error) { + bc.RLock() + defer bc.RUnlock() + if itm, ok := bc.items[key]; ok { + if itm.isExpire() { + return nil, ErrKeyExpired + } + return itm.val, nil + } + return nil, ErrKeyNotExist +} + +// GetMulti gets caches from memory. +// If non-existent or expired, return nil. +func (bc *MemoryCache) GetMulti(ctx context.Context, keys []string) ([]interface{}, error) { + rc := make([]interface{}, len(keys)) + keysErr := make([]string, 0) + + for i, ki := range keys { + val, err := bc.Get(context.Background(), ki) + if err != nil { + keysErr = append(keysErr, fmt.Sprintf("key [%s] error: %s", ki, err.Error())) + continue + } + rc[i] = val + } + + if len(keysErr) == 0 { + return rc, nil + } + return rc, berror.Error(MultiGetFailed, strings.Join(keysErr, "; ")) +} + +// Put puts cache into memory. +// If lifespan is 0, it will never overwrite this value unless restarted +func (bc *MemoryCache) Put(ctx context.Context, key string, val interface{}, timeout time.Duration) error { + bc.Lock() + defer bc.Unlock() + bc.items[key] = &MemoryItem{ + val: val, + createdTime: time.Now(), + lifespan: timeout, + } + return nil +} + +// Delete cache in memory. +// If the key is not found, it will not return error +func (bc *MemoryCache) Delete(ctx context.Context, key string) error { + bc.Lock() + defer bc.Unlock() + delete(bc.items, key) + return nil +} + +// Incr increases cache counter in memory. +// Supports int,int32,int64,uint,uint32,uint64. +func (bc *MemoryCache) Incr(ctx context.Context, key string) error { + bc.Lock() + defer bc.Unlock() + itm, ok := bc.items[key] + if !ok { + return ErrKeyNotExist + } + + val, err := incr(itm.val) + if err != nil { + return err + } + itm.val = val + return nil +} + +// Decr decreases counter in memory. +func (bc *MemoryCache) Decr(ctx context.Context, key string) error { + bc.Lock() + defer bc.Unlock() + itm, ok := bc.items[key] + if !ok { + return ErrKeyNotExist + } + + val, err := decr(itm.val) + if err != nil { + return err + } + itm.val = val + return nil +} + +// IsExist checks if cache exists in memory. +func (bc *MemoryCache) IsExist(ctx context.Context, key string) (bool, error) { + bc.RLock() + defer bc.RUnlock() + if v, ok := bc.items[key]; ok { + return !v.isExpire(), nil + } + return false, nil +} + +// ClearAll deletes all cache in memory. +func (bc *MemoryCache) ClearAll(context.Context) error { + bc.Lock() + defer bc.Unlock() + bc.items = make(map[string]*MemoryItem) + return nil +} + +// StartAndGC starts memory cache. Checks expiration in every clock time. +func (bc *MemoryCache) StartAndGC(config string) error { + var cf map[string]int + if err := json.Unmarshal([]byte(config), &cf); err != nil { + return berror.Wrapf(err, InvalidMemoryCacheCfg, "invalid config, please check your input: %s", config) + } + if _, ok := cf["interval"]; !ok { + cf = make(map[string]int) + cf["interval"] = DefaultEvery + } + dur := time.Duration(cf["interval"]) * time.Second + bc.Every = cf["interval"] + bc.dur = dur + go bc.vacuum() + return nil +} + +// check expiration. +func (bc *MemoryCache) vacuum() { + bc.RLock() + every := bc.Every + bc.RUnlock() + + if every < 1 { + return + } + for { + <-time.After(bc.dur) + bc.RLock() + if bc.items == nil { + bc.RUnlock() + return + } + bc.RUnlock() + if keys := bc.expiredKeys(); len(keys) != 0 { + bc.clearItems(keys) + } + } +} + +// expiredKeys returns keys list which are expired. +func (bc *MemoryCache) expiredKeys() (keys []string) { + bc.RLock() + defer bc.RUnlock() + for key, itm := range bc.items { + if itm.isExpire() { + keys = append(keys, key) + } + } + return +} + +// ClearItems removes all items who's key is in keys +func (bc *MemoryCache) clearItems(keys []string) { + bc.Lock() + defer bc.Unlock() + for _, key := range keys { + delete(bc.items, key) + } +} + +func init() { + Register("memory", NewMemoryCache) +} diff --git a/testing/assertions.go b/client/cache/module.go similarity index 80% rename from testing/assertions.go rename to client/cache/module.go index 96c5d4ddc9..5a4e499ea1 100644 --- a/testing/assertions.go +++ b/client/cache/module.go @@ -1,10 +1,10 @@ -// Copyright 2014 beego Author. All Rights Reserved. +// Copyright 2021 beego // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // -// http://www.apache.org/licenses/LICENSE-2.0 +// http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, @@ -12,4 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -package testing +package cache + +const moduleName = "cache" diff --git a/client/cache/random_expired_cache.go b/client/cache/random_expired_cache.go new file mode 100644 index 0000000000..395a0a92ad --- /dev/null +++ b/client/cache/random_expired_cache.go @@ -0,0 +1,75 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cache + +import ( + "context" + "math/rand" + "sync/atomic" + "time" +) + +// RandomExpireCacheOption implement genreate random time offset expired option +type RandomExpireCacheOption func(*RandomExpireCache) + +// WithRandomExpireOffsetFunc returns a RandomExpireCacheOption that configures the offset function +func WithRandomExpireOffsetFunc(fn func() time.Duration) RandomExpireCacheOption { + return func(cache *RandomExpireCache) { + cache.offset = fn + } +} + +// RandomExpireCache prevent cache batch invalidation +// Cache random time offset expired +type RandomExpireCache struct { + Cache + offset func() time.Duration +} + +// Put random time offset expired +func (rec *RandomExpireCache) Put(ctx context.Context, key string, val interface{}, timeout time.Duration) error { + timeout += rec.offset() + return rec.Cache.Put(ctx, key, val, timeout) +} + +// NewRandomExpireCache return random expire cache struct +func NewRandomExpireCache(adapter Cache, opts ...RandomExpireCacheOption) Cache { + rec := RandomExpireCache{ + Cache: adapter, + offset: defaultExpiredFunc(), + } + for _, fn := range opts { + fn(&rec) + } + return &rec +} + +// defaultExpiredFunc return a func that used to generate random time offset (range: [3s,8s)) expired +func defaultExpiredFunc() func() time.Duration { + const size = 5 + var randTimes [size]time.Duration + for i := range randTimes { + randTimes[i] = time.Duration(i+3) * time.Second + } + // shuffle values + for i := range randTimes { + n := rand.Intn(size) + randTimes[i], randTimes[n] = randTimes[n], randTimes[i] + } + var i uint64 + return func() time.Duration { + return randTimes[atomic.AddUint64(&i, 1)%size] + } +} diff --git a/client/cache/random_expired_cache_test.go b/client/cache/random_expired_cache_test.go new file mode 100644 index 0000000000..07cdb7a110 --- /dev/null +++ b/client/cache/random_expired_cache_test.go @@ -0,0 +1,128 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cache + +import ( + "context" + "fmt" + "math/rand" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestRandomExpireCache(t *testing.T) { + bm, err := NewCache("memory", `{"interval":20}`) + assert.Nil(t, err) + + cache := NewRandomExpireCache(bm) + // should not be nil + assert.NotNil(t, cache.(*RandomExpireCache).offset) + + timeoutDuration := 3 * time.Second + + if err = cache.Put(context.Background(), "Leon Ding", 22, timeoutDuration); err != nil { + t.Error("set Error", err) + } + + // testing random expire cache + time.Sleep(timeoutDuration + 3 + time.Second) + + if res, _ := cache.IsExist(context.Background(), "Leon Ding"); !res { + t.Error("check err") + } + + if v, _ := cache.Get(context.Background(), "Leon Ding"); v.(int) != 22 { + t.Error("get err") + } + + assert.Nil(t, cache.Delete(context.Background(), "Leon Ding")) + res, _ := cache.IsExist(context.Background(), "Leon Ding") + assert.False(t, res) + + assert.Nil(t, cache.Put(context.Background(), "Leon Ding", "author", timeoutDuration)) + + assert.Nil(t, cache.Delete(context.Background(), "astaxie")) + res, _ = cache.IsExist(context.Background(), "astaxie") + assert.False(t, res) + + assert.Nil(t, cache.Put(context.Background(), "astaxie", "author", timeoutDuration)) + + res, _ = cache.IsExist(context.Background(), "astaxie") + assert.True(t, res) + + v, _ := cache.Get(context.Background(), "astaxie") + assert.Equal(t, "author", v) + + assert.Nil(t, cache.Put(context.Background(), "astaxie1", "author1", timeoutDuration)) + + res, _ = cache.IsExist(context.Background(), "astaxie1") + assert.True(t, res) + + vv, _ := cache.GetMulti(context.Background(), []string{"astaxie", "astaxie1"}) + assert.Equal(t, 2, len(vv)) + assert.Equal(t, "author", vv[0]) + assert.Equal(t, "author1", vv[1]) + + vv, err = cache.GetMulti(context.Background(), []string{"astaxie0", "astaxie1"}) + assert.Equal(t, 2, len(vv)) + assert.Nil(t, vv[0]) + assert.Equal(t, "author1", vv[1]) + + assert.NotNil(t, err) + assert.True(t, strings.Contains(err.Error(), "key isn't exist")) +} + +func TestWithRandomExpireOffsetFunc(t *testing.T) { + bm, err := NewCache("memory", `{"interval":20}`) + assert.Nil(t, err) + + magic := -time.Duration(rand.Int()) + cache := NewRandomExpireCache(bm, WithRandomExpireOffsetFunc(func() time.Duration { + return magic + })) + // offset should return the magic value + assert.Equal(t, magic, cache.(*RandomExpireCache).offset()) +} + +func ExampleNewRandomExpireCache() { + mc := NewMemoryCache() + // use the default strategy which will generate random time offset (range: [3s,8s)) expired + c := NewRandomExpireCache(mc) + // so the expiration will be [1m3s, 1m8s) + err := c.Put(context.Background(), "hello", "world", time.Minute) + if err != nil { + panic(err) + } + + c = NewRandomExpireCache(mc, + // based on the expiration + WithRandomExpireOffsetFunc(func() time.Duration { + val := rand.Int31n(100) + fmt.Printf("calculate offset") + return time.Duration(val) * time.Second + })) + + // so the expiration will be [1m0s, 1m100s) + err = c.Put(context.Background(), "hello", "world", time.Minute) + if err != nil { + panic(err) + } + + // Output: + // calculate offset +} diff --git a/client/cache/read_through.go b/client/cache/read_through.go new file mode 100644 index 0000000000..573b3255a2 --- /dev/null +++ b/client/cache/read_through.go @@ -0,0 +1,61 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cache + +import ( + "context" + "time" + + "github.com/beego/beego/v2/core/berror" +) + +// readThroughCache is a decorator +// add the read through function to the original Cache function +type readThroughCache struct { + Cache + expiration time.Duration + loadFunc func(ctx context.Context, key string) (any, error) +} + +// NewReadThroughCache create readThroughCache +func NewReadThroughCache(cache Cache, expiration time.Duration, + loadFunc func(ctx context.Context, key string) (any, error), +) (Cache, error) { + if loadFunc == nil { + return nil, berror.Error(InvalidLoadFunc, "loadFunc cannot be nil") + } + return &readThroughCache{ + Cache: cache, + expiration: expiration, + loadFunc: loadFunc, + }, nil +} + +// Get will try to call the LoadFunc to load data if the Cache returns value nil or non-nil error. +func (c *readThroughCache) Get(ctx context.Context, key string) (any, error) { + val, err := c.Cache.Get(ctx, key) + if val == nil || err != nil { + val, err = c.loadFunc(ctx, key) + if err != nil { + return nil, berror.Wrap( + err, LoadFuncFailed, "cache unable to load data") + } + err = c.Cache.Put(ctx, key, val, c.expiration) + if err != nil { + return val, err + } + } + return val, nil +} diff --git a/client/cache/read_through_test.go b/client/cache/read_through_test.go new file mode 100644 index 0000000000..45f12f842b --- /dev/null +++ b/client/cache/read_through_test.go @@ -0,0 +1,171 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cache + +import ( + "context" + "errors" + "fmt" + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "github.com/beego/beego/v2/core/berror" +) + +func TestReadThroughCache_Memory_Get(t *testing.T) { + bm, err := NewCache("memory", `{"interval":20}`) + assert.Nil(t, err) + testReadThroughCacheGet(t, bm) +} + +func TestReadThroughCache_file_Get(t *testing.T) { + fc := NewFileCache().(*FileCache) + fc.CachePath = "////aaa" + err := fc.Init() + assert.NotNil(t, err) + fc.CachePath = getTestCacheFilePath() + err = fc.Init() + assert.Nil(t, err) + testReadThroughCacheGet(t, fc) +} + +func testReadThroughCacheGet(t *testing.T, bm Cache) { + testCases := []struct { + name string + key string + value string + cache Cache + wantErr error + }{ + { + name: "Get load err", + key: "key0", + cache: func() Cache { + kvs := map[string]any{"key0": "value0"} + db := &MockOrm{kvs: kvs} + loadfunc := func(ctx context.Context, key string) (any, error) { + val, er := db.Load(key) + if er != nil { + return nil, er + } + return val, nil + } + c, err := NewReadThroughCache(bm, 3*time.Second, loadfunc) + assert.Nil(t, err) + return c + }(), + wantErr: func() error { + err := errors.New("the key not exist") + return berror.Wrap( + err, LoadFuncFailed, "cache unable to load data") + }(), + }, + { + name: "Get cache exist", + key: "key1", + value: "value1", + cache: func() Cache { + keysMap := map[string]int{"key1": 1} + kvs := map[string]any{"key1": "value1"} + db := &MockOrm{keysMap: keysMap, kvs: kvs} + loadfunc := func(ctx context.Context, key string) (any, error) { + val, er := db.Load(key) + if er != nil { + return nil, er + } + return val, nil + } + c, err := NewReadThroughCache(bm, 3*time.Second, loadfunc) + assert.Nil(t, err) + err = c.Put(context.Background(), "key1", "value1", 3*time.Second) + assert.Nil(t, err) + return c + }(), + }, + { + name: "Get loadFunc exist", + key: "key2", + value: "value2", + cache: func() Cache { + keysMap := map[string]int{"key2": 1} + kvs := map[string]any{"key2": "value2"} + db := &MockOrm{keysMap: keysMap, kvs: kvs} + loadfunc := func(ctx context.Context, key string) (any, error) { + val, er := db.Load(key) + if er != nil { + return nil, er + } + return val, nil + } + c, err := NewReadThroughCache(bm, 3*time.Second, loadfunc) + assert.Nil(t, err) + return c + }(), + }, + } + _, err := NewReadThroughCache(bm, 3*time.Second, nil) + assert.Equal(t, berror.Error(InvalidLoadFunc, "loadFunc cannot be nil"), err) + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + c := tc.cache + val, err := c.Get(context.Background(), tc.key) + if err != nil { + assert.EqualError(t, tc.wantErr, err.Error()) + return + } + assert.Equal(t, tc.value, val) + }) + } +} + +type MockOrm struct { + keysMap map[string]int + kvs map[string]any +} + +func (m *MockOrm) Load(key string) (any, error) { + _, ok := m.keysMap[key] + if !ok { + return nil, errors.New("the key not exist") + } + return m.kvs[key], nil +} + +func ExampleNewReadThroughCache() { + c := NewMemoryCache() + var err error + c, err = NewReadThroughCache(c, + // expiration, same as the expiration of key + time.Minute, + // load func, how to load data if the key is absent. + // in general, you should load data from database. + func(ctx context.Context, key string) (any, error) { + return fmt.Sprintf("hello, %s", key), nil + }) + if err != nil { + panic(err) + } + + val, err := c.Get(context.Background(), "Beego") + if err != nil { + panic(err) + } + fmt.Print(val) + + // Output: + // hello, Beego +} diff --git a/client/cache/redis/redis.go b/client/cache/redis/redis.go new file mode 100644 index 0000000000..c2b2eaa8a1 --- /dev/null +++ b/client/cache/redis/redis.go @@ -0,0 +1,350 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package redis for cache provider +// +// depend on github.com/gomodule/redigo/redis +// +// go install github.com/gomodule/redigo/redis +// +// Usage: +// import( +// +// _ "github.com/beego/beego/v2/client/cache/redis" +// "github.com/beego/beego/v2/client/cache" +// +// ) +// +// bm, err := cache.NewCache("redis", `{"conn":"127.0.0.1:11211"}`) +package redis + +import ( + "context" + "encoding/json" + "fmt" + "strconv" + "strings" + "time" + + "github.com/gomodule/redigo/redis" + + "github.com/beego/beego/v2/client/cache" + "github.com/beego/beego/v2/core/berror" +) + +const ( + // DefaultKey defines the collection name of redis for the cache adapter. + DefaultKey = "beecacheRedis" + // defaultMaxIdle defines the default max idle connection number. + defaultMaxIdle = 3 + // defaultTimeout defines the default timeout . + defaultTimeout = time.Second * 180 +) + +// Cache is Redis cache adapter. +type Cache struct { + p *redis.Pool // redis connection pool + conninfo string + dbNum int + // key actually is prefix. + key string + password string + maxIdle int + + // skipEmptyPrefix for backward compatible, + // check function associate + // see https://github.com/beego/beego/issues/5248 + skipEmptyPrefix bool + + // Timeout value (less than the redis server's timeout value). + // Timeout used for idle connection + timeout time.Duration +} + +// NewRedisCache creates a new redis cache with default collection name. +func NewRedisCache() cache.Cache { + return &Cache{key: DefaultKey} +} + +// Execute the redis commands. args[0] must be the key name +func (rc *Cache) do(commandName string, args ...interface{}) (interface{}, error) { + args[0] = rc.associate(args[0]) + c := rc.p.Get() + defer func() { + _ = c.Close() + }() + + reply, err := c.Do(commandName, args...) + if err != nil { + return nil, berror.Wrapf(err, cache.RedisCacheCurdFailed, + "could not execute this command: %s", commandName) + } + + return reply, nil +} + +// associate with config key. +func (rc *Cache) associate(originKey interface{}) string { + if rc.key == "" && rc.skipEmptyPrefix { + return fmt.Sprintf("%s", originKey) + } + return fmt.Sprintf("%s:%s", rc.key, originKey) +} + +// Get cache from redis. +func (rc *Cache) Get(ctx context.Context, key string) (interface{}, error) { + if v, err := rc.do("GET", key); err == nil { + return v, nil + } else { + return nil, err + } +} + +// GetMulti gets cache from redis. +func (rc *Cache) GetMulti(ctx context.Context, keys []string) ([]interface{}, error) { + c := rc.p.Get() + defer func() { + _ = c.Close() + }() + var args []interface{} + for _, key := range keys { + args = append(args, rc.associate(key)) + } + return redis.Values(c.Do("MGET", args...)) +} + +// Put puts cache into redis. +func (rc *Cache) Put(ctx context.Context, key string, val interface{}, timeout time.Duration) error { + _, err := rc.do("SETEX", key, int64(timeout/time.Second), val) + return err +} + +// Delete deletes a key's cache in redis. +func (rc *Cache) Delete(ctx context.Context, key string) error { + _, err := rc.do("DEL", key) + return err +} + +// IsExist checks cache's existence in redis. +func (rc *Cache) IsExist(ctx context.Context, key string) (bool, error) { + v, err := redis.Bool(rc.do("EXISTS", key)) + if err != nil { + return false, err + } + return v, nil +} + +// Incr increases a key's counter in redis. +func (rc *Cache) Incr(ctx context.Context, key string) error { + _, err := redis.Bool(rc.do("INCRBY", key, 1)) + return err +} + +// Decr decreases a key's counter in redis. +func (rc *Cache) Decr(ctx context.Context, key string) error { + _, err := redis.Bool(rc.do("INCRBY", key, -1)) + return err +} + +// ClearAll deletes all cache in the redis collection +// Be careful about this method, because it scans all keys and the delete them one by one +func (rc *Cache) ClearAll(context.Context) error { + cachedKeys, err := rc.Scan(rc.key + ":*") + if err != nil { + return err + } + c := rc.p.Get() + defer func() { + _ = c.Close() + }() + for _, str := range cachedKeys { + if _, err = c.Do("DEL", str); err != nil { + return err + } + } + return err +} + +// Scan scans all keys matching a given pattern. +func (rc *Cache) Scan(pattern string) (keys []string, err error) { + c := rc.p.Get() + defer func() { + _ = c.Close() + }() + var ( + cursor uint64 = 0 // start + result []interface{} + list []string + ) + for { + result, err = redis.Values(c.Do("SCAN", cursor, "MATCH", pattern, "COUNT", 1024)) + if err != nil { + return + } + list, err = redis.Strings(result[1], nil) + if err != nil { + return + } + keys = append(keys, list...) + cursor, err = redis.Uint64(result[0], nil) + if err != nil { + return + } + if cursor == 0 { // over + return + } + } +} + +// StartAndGC starts the redis cache adapter. +// config: must be in this format {"key":"collection key","conn":"connection info","dbNum":"0", "skipEmptyPrefix":"true"} +// Cached items in redis are stored forever, no garbage collection happens +func (rc *Cache) StartAndGC(config string) error { + err := rc.parseConf(config) + if err != nil { + return err + } + + rc.connectInit() + + c := rc.p.Get() + defer func() { + _ = c.Close() + }() + + // test connection + if err = c.Err(); err != nil { + return berror.Wrapf(err, cache.InvalidConnection, + "can not connect to remote redis server, please check the connection info and network state: %s", config) + } + return nil +} + +func (rc *Cache) parseConf(config string) error { + var cf redisConfig + err := json.Unmarshal([]byte(config), &cf) + if err != nil { + return berror.Wrapf(err, cache.InvalidRedisCacheCfg, "could not unmarshal the config: %s", config) + } + + err = cf.parse() + if err != nil { + return err + } + + rc.dbNum = cf.dbNum + rc.key = cf.Key + rc.conninfo = cf.Conn + rc.password = cf.password + rc.maxIdle = cf.maxIdle + rc.timeout = cf.timeout + rc.skipEmptyPrefix = cf.skipEmptyPrefix + + return nil +} + +type redisConfig struct { + DbNum string `json:"dbNum"` + SkipEmptyPrefix string `json:"skipEmptyPrefix"` + Key string `json:"key"` + // Format redis://@: + Conn string `json:"conn"` + MaxIdle string `json:"maxIdle"` + TimeoutStr string `json:"timeout"` + + dbNum int + skipEmptyPrefix bool + maxIdle int + // parse from Conn + password string + // timeout used for idle connection, default is 180 seconds. + timeout time.Duration +} + +// parse parses the config. +// If the necessary settings have not been set, it will return an error. +// It will fill the default values if some fields are missing. +func (cf *redisConfig) parse() error { + if cf.Conn == "" { + return berror.Error(cache.InvalidRedisCacheCfg, "config missing conn field") + } + + // Format redis://@: + cf.Conn = strings.Replace(cf.Conn, "redis://", "", 1) + if i := strings.Index(cf.Conn, "@"); i > -1 { + cf.password = cf.Conn[0:i] + cf.Conn = cf.Conn[i+1:] + } + + if cf.Key == "" { + cf.Key = DefaultKey + } + + if cf.DbNum != "" { + cf.dbNum, _ = strconv.Atoi(cf.DbNum) + } + + if cf.SkipEmptyPrefix != "" { + cf.skipEmptyPrefix, _ = strconv.ParseBool(cf.SkipEmptyPrefix) + } + + if cf.MaxIdle == "" { + cf.maxIdle = defaultMaxIdle + } else { + cf.maxIdle, _ = strconv.Atoi(cf.MaxIdle) + } + + if v, err := time.ParseDuration(cf.TimeoutStr); err == nil { + cf.timeout = v + } else { + cf.timeout = defaultTimeout + } + + return nil +} + +// connect to redis. +func (rc *Cache) connectInit() { + dialFunc := func() (c redis.Conn, err error) { + c, err = redis.Dial("tcp", rc.conninfo) + if err != nil { + return nil, berror.Wrapf(err, cache.DialFailed, + "could not dial to remote server: %s ", rc.conninfo) + } + + if rc.password != "" { + if _, err = c.Do("AUTH", rc.password); err != nil { + _ = c.Close() + return nil, err + } + } + + _, selecterr := c.Do("SELECT", rc.dbNum) + if selecterr != nil { + _ = c.Close() + return nil, selecterr + } + return + } + // initialize a new pool + rc.p = &redis.Pool{ + MaxIdle: rc.maxIdle, + IdleTimeout: rc.timeout, + Dial: dialFunc, + } +} + +func init() { + cache.Register("redis", NewRedisCache) +} diff --git a/client/cache/redis/redis_test.go b/client/cache/redis/redis_test.go new file mode 100644 index 0000000000..d70f1817ce --- /dev/null +++ b/client/cache/redis/redis_test.go @@ -0,0 +1,361 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package redis + +import ( + "context" + "errors" + "fmt" + "os" + "testing" + "time" + + "github.com/gomodule/redigo/redis" + "github.com/stretchr/testify/assert" + + "github.com/beego/beego/v2/client/cache" + "github.com/beego/beego/v2/core/berror" +) + +func TestRedisCache(t *testing.T) { + redisAddr := os.Getenv("REDIS_ADDR") + if redisAddr == "" { + redisAddr = "127.0.0.1:6379" + } + + bm, err := cache.NewCache("redis", fmt.Sprintf(`{"conn": "%s"}`, redisAddr)) + assert.Nil(t, err) + timeoutDuration := 3 * time.Second + + assert.Nil(t, bm.Put(context.Background(), "astaxie", 1, timeoutDuration)) + + res, _ := bm.IsExist(context.Background(), "astaxie") + assert.True(t, res) + + time.Sleep(5 * time.Second) + + res, _ = bm.IsExist(context.Background(), "astaxie") + assert.False(t, res) + + assert.Nil(t, bm.Put(context.Background(), "astaxie", 1, timeoutDuration)) + + val, _ := bm.Get(context.Background(), "astaxie") + v, _ := redis.Int(val, err) + assert.Equal(t, 1, v) + + assert.Nil(t, bm.Incr(context.Background(), "astaxie")) + val, _ = bm.Get(context.Background(), "astaxie") + v, _ = redis.Int(val, err) + assert.Equal(t, 2, v) + + assert.Nil(t, bm.Decr(context.Background(), "astaxie")) + + val, _ = bm.Get(context.Background(), "astaxie") + v, _ = redis.Int(val, err) + assert.Equal(t, 1, v) + assert.Nil(t, bm.Delete(context.Background(), "astaxie")) + + res, _ = bm.IsExist(context.Background(), "astaxie") + assert.False(t, res) + + assert.Nil(t, bm.Put(context.Background(), "astaxie", "author", timeoutDuration)) + // test string + + res, _ = bm.IsExist(context.Background(), "astaxie") + assert.True(t, res) + + val, _ = bm.Get(context.Background(), "astaxie") + vs, _ := redis.String(val, err) + assert.Equal(t, "author", vs) + + // test GetMulti + assert.Nil(t, bm.Put(context.Background(), "astaxie1", "author1", timeoutDuration)) + + res, _ = bm.IsExist(context.Background(), "astaxie1") + assert.True(t, res) + + vv, _ := bm.GetMulti(context.Background(), []string{"astaxie", "astaxie1"}) + assert.Equal(t, 2, len(vv)) + vs, _ = redis.String(vv[0], nil) + assert.Equal(t, "author", vs) + + vs, _ = redis.String(vv[1], nil) + assert.Equal(t, "author1", vs) + + vv, _ = bm.GetMulti(context.Background(), []string{"astaxie0", "astaxie1"}) + assert.Nil(t, vv[0]) + + vs, _ = redis.String(vv[1], nil) + assert.Equal(t, "author1", vs) + + // test clear all + assert.Nil(t, bm.ClearAll(context.Background())) +} + +func TestCacheScan(t *testing.T) { + timeoutDuration := 10 * time.Second + + addr := os.Getenv("REDIS_ADDR") + if addr == "" { + addr = "127.0.0.1:6379" + } + + // init + bm, err := cache.NewCache("redis", fmt.Sprintf(`{"conn": "%s"}`, addr)) + + assert.Nil(t, err) + // insert all + for i := 0; i < 100; i++ { + assert.Nil(t, bm.Put(context.Background(), fmt.Sprintf("astaxie%d", i), fmt.Sprintf("author%d", i), timeoutDuration)) + } + time.Sleep(time.Second) + // scan all for the first time + keys, err := bm.(*Cache).Scan(DefaultKey + ":*") + assert.Nil(t, err) + + assert.Equal(t, 100, len(keys), "scan all error") + + // clear all + assert.Nil(t, bm.ClearAll(context.Background())) + + // scan all for the second time + keys, err = bm.(*Cache).Scan(DefaultKey + ":*") + assert.Nil(t, err) + assert.Equal(t, 0, len(keys)) +} + +func TestReadThroughCache_redis_Get(t *testing.T) { + bm, err := cache.NewCache("redis", fmt.Sprintf(`{"conn": "%s"}`, "127.0.0.1:6379")) + assert.Nil(t, err) + + testReadThroughCacheGet(t, bm) +} + +func testReadThroughCacheGet(t *testing.T, bm cache.Cache) { + testCases := []struct { + name string + key string + value string + cache cache.Cache + wantErr error + }{ + { + name: "Get load err", + key: "key0", + cache: func() cache.Cache { + kvs := map[string]any{"key0": "value0"} + db := &MockOrm{kvs: kvs} + loadfunc := func(ctx context.Context, key string) (any, error) { + v, er := db.Load(key) + if er != nil { + return nil, er + } + val := []byte(v.(string)) + return val, nil + } + c, err := cache.NewReadThroughCache(bm, 3*time.Second, loadfunc) + assert.Nil(t, err) + return c + }(), + wantErr: func() error { + err := errors.New("the key not exist") + return berror.Wrap( + err, cache.LoadFuncFailed, "cache unable to load data") + }(), + }, + { + name: "Get cache exist", + key: "key1", + value: "value1", + cache: func() cache.Cache { + keysMap := map[string]int{"key1": 1} + kvs := map[string]any{"key1": "value1"} + db := &MockOrm{keysMap: keysMap, kvs: kvs} + loadfunc := func(ctx context.Context, key string) (any, error) { + v, er := db.Load(key) + if er != nil { + return nil, er + } + val := []byte(v.(string)) + return val, nil + } + c, err := cache.NewReadThroughCache(bm, 3*time.Second, loadfunc) + assert.Nil(t, err) + err = c.Put(context.Background(), "key1", "value1", 3*time.Second) + assert.Nil(t, err) + return c + }(), + }, + { + name: "Get loadFunc exist", + key: "key2", + value: "value2", + cache: func() cache.Cache { + keysMap := map[string]int{"key2": 1} + kvs := map[string]any{"key2": "value2"} + db := &MockOrm{keysMap: keysMap, kvs: kvs} + loadfunc := func(ctx context.Context, key string) (any, error) { + v, er := db.Load(key) + if er != nil { + return nil, er + } + val := []byte(v.(string)) + return val, nil + } + c, err := cache.NewReadThroughCache(bm, 3*time.Second, loadfunc) + assert.Nil(t, err) + return c + }(), + }, + } + _, err := cache.NewReadThroughCache(bm, 3*time.Second, nil) + assert.Equal(t, berror.Error(cache.InvalidLoadFunc, "loadFunc cannot be nil"), err) + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + bs := []byte(tc.value) + c := tc.cache + val, err := c.Get(context.Background(), tc.key) + if err != nil { + assert.EqualError(t, tc.wantErr, err.Error()) + return + } + assert.Equal(t, bs, val) + }) + } +} + +type MockOrm struct { + keysMap map[string]int + kvs map[string]any +} + +func (m *MockOrm) Load(key string) (any, error) { + _, ok := m.keysMap[key] + if !ok { + return nil, errors.New("the key not exist") + } + return m.kvs[key], nil +} + +func TestCache_associate(t *testing.T) { + testCases := []struct { + name string + skipEmptyPrefix bool + prefix string + input string + wantRes string + }{ + { + name: "skip prefix", + skipEmptyPrefix: true, + prefix: "", + input: "my-key", + wantRes: "my-key", + }, + { + name: "skip prefix but prefix not empty", + skipEmptyPrefix: true, + prefix: "abc", + input: "my-key", + wantRes: "abc:my-key", + }, + { + name: "using empty prefix", + skipEmptyPrefix: false, + prefix: "", + input: "my-key", + wantRes: ":my-key", + }, + { + name: "using prefix", + prefix: "abc", + input: "my-key", + wantRes: "abc:my-key", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + c := NewRedisCache().(*Cache) + c.skipEmptyPrefix = tc.skipEmptyPrefix + c.key = tc.prefix + res := c.associate(tc.input) + assert.Equal(t, tc.wantRes, res) + }) + } +} + +func TestCache_parseConf(t *testing.T) { + tests := []struct { + name string + + configStr string + + wantCache Cache + wantErr error + }{ + { + name: "just conn", + configStr: `{ + "conn": "127.0.0.1:6379" +}`, + + wantCache: Cache{ + conninfo: "127.0.0.1:6379", + dbNum: 0, + key: DefaultKey, + password: "", + maxIdle: defaultMaxIdle, + skipEmptyPrefix: false, + timeout: defaultTimeout, + }, + wantErr: nil, + }, + + { + name: "all", + configStr: `{ + "dbNum": "2", + "skipEmptyPrefix": "true", + "key": "mykey", + "conn": "redis://mypwd@127.0.0.1:6379", + "maxIdle": "10", + "timeout": "30s" +}`, + + wantCache: Cache{ + conninfo: "127.0.0.1:6379", + dbNum: 2, + key: "mykey", + password: "mypwd", + maxIdle: 10, + skipEmptyPrefix: true, + timeout: time.Second * 30, + }, + wantErr: nil, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := Cache{} + err := c.parseConf(tt.configStr) + assert.Equal(t, tt.wantErr, err) + if err != nil { + return + } + assert.Equal(t, tt.wantCache, c) + }) + } +} diff --git a/client/cache/singleflight.go b/client/cache/singleflight.go new file mode 100644 index 0000000000..42b7f84562 --- /dev/null +++ b/client/cache/singleflight.go @@ -0,0 +1,64 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cache + +import ( + "context" + "time" + + "golang.org/x/sync/singleflight" + + "github.com/beego/beego/v2/core/berror" +) + +// SingleflightCache +// This is a very simple decorator mode +type SingleflightCache struct { + Cache + group *singleflight.Group + expiration time.Duration + loadFunc func(ctx context.Context, key string) (any, error) +} + +// NewSingleflightCache create SingleflightCache +func NewSingleflightCache(c Cache, expiration time.Duration, + loadFunc func(ctx context.Context, key string) (any, error), +) (Cache, error) { + if loadFunc == nil { + return nil, berror.Error(InvalidLoadFunc, "loadFunc cannot be nil") + } + return &SingleflightCache{ + Cache: c, + group: &singleflight.Group{}, + expiration: expiration, + loadFunc: loadFunc, + }, nil +} + +// Get In the Get method, single flight is used to load data and write back the cache. +func (s *SingleflightCache) Get(ctx context.Context, key string) (any, error) { + val, err := s.Cache.Get(ctx, key) + if val == nil || err != nil { + val, err, _ = s.group.Do(key, func() (interface{}, error) { + v, er := s.loadFunc(ctx, key) + if er != nil { + return nil, berror.Wrap(er, LoadFuncFailed, "cache unable to load data") + } + er = s.Cache.Put(ctx, key, v, s.expiration) + return v, er + }) + } + return val, err +} diff --git a/client/cache/singleflight_test.go b/client/cache/singleflight_test.go new file mode 100644 index 0000000000..268bb8f98f --- /dev/null +++ b/client/cache/singleflight_test.go @@ -0,0 +1,90 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cache + +import ( + "context" + "fmt" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestSingleflight_Memory_Get(t *testing.T) { + bm, err := NewCache("memory", `{"interval":20}`) + assert.Nil(t, err) + + testSingleflightCacheConcurrencyGet(t, bm) +} + +func TestSingleflight_file_Get(t *testing.T) { + fc := NewFileCache().(*FileCache) + fc.CachePath = "////aaa" + err := fc.Init() + assert.NotNil(t, err) + fc.CachePath = getTestCacheFilePath() + err = fc.Init() + assert.Nil(t, err) + + testSingleflightCacheConcurrencyGet(t, fc) +} + +func testSingleflightCacheConcurrencyGet(t *testing.T, bm Cache) { + key, value := "key3", "value3" + db := &MockOrm{keysMap: map[string]int{key: 1}, kvs: map[string]any{key: value}} + c, err := NewSingleflightCache(bm, 10*time.Second, + func(ctx context.Context, key string) (any, error) { + val, er := db.Load(key) + if er != nil { + return nil, er + } + return val, nil + }) + assert.Nil(t, err) + + var wg sync.WaitGroup + wg.Add(10) + for i := 0; i < 10; i++ { + go func() { + defer wg.Done() + val, err := c.Get(context.Background(), key) + if err != nil { + t.Error(err) + } + assert.Equal(t, value, val) + }() + time.Sleep(1 * time.Millisecond) + } + wg.Wait() +} + +func ExampleNewSingleflightCache() { + c := NewMemoryCache() + c, err := NewSingleflightCache(c, time.Minute, func(ctx context.Context, key string) (any, error) { + return fmt.Sprintf("hello, %s", key), nil + }) + if err != nil { + panic(err) + } + val, err := c.Get(context.Background(), "Beego") + if err != nil { + panic(err) + } + fmt.Print(val) + // Output: + // hello, Beego +} diff --git a/client/cache/ssdb/ssdb.go b/client/cache/ssdb/ssdb.go new file mode 100644 index 0000000000..54558ea3fd --- /dev/null +++ b/client/cache/ssdb/ssdb.go @@ -0,0 +1,201 @@ +package ssdb + +import ( + "context" + "encoding/json" + "fmt" + "strconv" + "strings" + "time" + + "github.com/ssdb/gossdb/ssdb" + + "github.com/beego/beego/v2/client/cache" + "github.com/beego/beego/v2/core/berror" +) + +// Cache SSDB adapter +type Cache struct { + conn *ssdb.Client + conninfo []string +} + +// NewSsdbCache creates new ssdb adapter. +func NewSsdbCache() cache.Cache { + return &Cache{} +} + +// Get gets a key's value from memcache. +func (rc *Cache) Get(ctx context.Context, key string) (interface{}, error) { + value, err := rc.conn.Get(key) + if err == nil { + return value, nil + } + return nil, berror.Wrapf(err, cache.SsdbCacheCurdFailed, "could not get value, key: %s", key) +} + +// GetMulti gets one or keys values from ssdb. +func (rc *Cache) GetMulti(ctx context.Context, keys []string) ([]interface{}, error) { + size := len(keys) + values := make([]interface{}, size) + + res, err := rc.conn.Do("multi_get", keys) + if err != nil { + return values, berror.Wrapf(err, cache.SsdbCacheCurdFailed, "multi_get failed, key: %v", keys) + } + + resSize := len(res) + keyIdx := make(map[string]int) + for i := 1; i < resSize; i += 2 { + keyIdx[res[i]] = i + } + + keysErr := make([]string, 0) + for i, ki := range keys { + if _, ok := keyIdx[ki]; !ok { + keysErr = append(keysErr, fmt.Sprintf("key [%s] error: %s", ki, "key not exist")) + continue + } + values[i] = res[keyIdx[ki]+1] + } + + if len(keysErr) != 0 { + return values, berror.Error(cache.MultiGetFailed, strings.Join(keysErr, "; ")) + } + + return values, nil +} + +// DelMulti deletes one or more keys from memcache +func (rc *Cache) DelMulti(keys []string) error { + _, err := rc.conn.Do("multi_del", keys) + return berror.Wrapf(err, cache.SsdbCacheCurdFailed, "multi_del failed: %v", keys) +} + +// Put puts value into memcache. +// value: must be of type string +func (rc *Cache) Put(ctx context.Context, key string, val interface{}, timeout time.Duration) error { + v, ok := val.(string) + if !ok { + return berror.Errorf(cache.InvalidSsdbCacheValue, "value must be string: %v", val) + } + var resp []string + var err error + ttl := int(timeout / time.Second) + if ttl < 0 { + resp, err = rc.conn.Do("set", key, v) + } else { + resp, err = rc.conn.Do("setx", key, v, ttl) + } + if err != nil { + return berror.Wrapf(err, cache.SsdbCacheCurdFailed, "set or setx failed, key: %s", key) + } + if len(resp) == 2 && resp[0] == "ok" { + return nil + } + return berror.Errorf(cache.SsdbBadResponse, "the response from SSDB server is invalid: %v", resp) +} + +// Delete deletes a value in memcache. +func (rc *Cache) Delete(ctx context.Context, key string) error { + _, err := rc.conn.Del(key) + return berror.Wrapf(err, cache.SsdbCacheCurdFailed, "del failed: %s", key) +} + +// Incr increases a key's counter. +func (rc *Cache) Incr(ctx context.Context, key string) error { + _, err := rc.conn.Do("incr", key, 1) + return berror.Wrapf(err, cache.SsdbCacheCurdFailed, "increase failed: %s", key) +} + +// Decr decrements a key's counter. +func (rc *Cache) Decr(ctx context.Context, key string) error { + _, err := rc.conn.Do("incr", key, -1) + return berror.Wrapf(err, cache.SsdbCacheCurdFailed, "decrease failed: %s", key) +} + +// IsExist checks if a key exists in memcache. +func (rc *Cache) IsExist(ctx context.Context, key string) (bool, error) { + resp, err := rc.conn.Do("exists", key) + if err != nil { + return false, berror.Wrapf(err, cache.SsdbCacheCurdFailed, "exists failed: %s", key) + } + if len(resp) == 2 && resp[1] == "1" { + return true, nil + } + return false, nil +} + +// ClearAll clears all cached items in ssdb. +// If there are many keys, this method may spent much time. +func (rc *Cache) ClearAll(context.Context) error { + keyStart, keyEnd, limit := "", "", 50 + resp, err := rc.Scan(keyStart, keyEnd, limit) + for err == nil { + size := len(resp) + if size == 1 { + return nil + } + keys := []string{} + for i := 1; i < size; i += 2 { + keys = append(keys, resp[i]) + } + _, e := rc.conn.Do("multi_del", keys) + if e != nil { + return berror.Wrapf(e, cache.SsdbCacheCurdFailed, "multi_del failed: %v", keys) + } + keyStart = resp[size-2] + resp, err = rc.Scan(keyStart, keyEnd, limit) + } + return berror.Wrap(err, cache.SsdbCacheCurdFailed, "scan failed") +} + +// Scan key all cached in ssdb. +func (rc *Cache) Scan(keyStart string, keyEnd string, limit int) ([]string, error) { + resp, err := rc.conn.Do("scan", keyStart, keyEnd, limit) + if err != nil { + return nil, err + } + return resp, nil +} + +// StartAndGC starts the memcache adapter. +// config: must be in the format {"conn":"connection info"}. +// If an error occurs during connection, an error is returned +func (rc *Cache) StartAndGC(config string) error { + var cf map[string]string + err := json.Unmarshal([]byte(config), &cf) + if err != nil { + return berror.Wrapf(err, cache.InvalidSsdbCacheCfg, + "unmarshal this config failed, it must be a valid json string: %s", config) + } + if _, ok := cf["conn"]; !ok { + return berror.Wrapf(err, cache.InvalidSsdbCacheCfg, + "Missing conn field: %s", config) + } + rc.conninfo = strings.Split(cf["conn"], ";") + return rc.connectInit() +} + +// connect to memcache and keep the connection. +func (rc *Cache) connectInit() error { + conninfoArray := strings.Split(rc.conninfo[0], ":") + if len(conninfoArray) < 2 { + return berror.Errorf(cache.InvalidSsdbCacheCfg, "The value of conn should be host:port: %s", rc.conninfo[0]) + } + host := conninfoArray[0] + port, e := strconv.Atoi(conninfoArray[1]) + if e != nil { + return berror.Errorf(cache.InvalidSsdbCacheCfg, "Port is invalid. It must be integer, %s", rc.conninfo[0]) + } + var err error + if rc.conn, err = ssdb.Connect(host, port); err != nil { + return berror.Wrapf(err, cache.InvalidConnection, + "could not connect to SSDB, please check your connection info, network and firewall: %s", rc.conninfo[0]) + } + return nil +} + +func init() { + cache.Register("ssdb", NewSsdbCache) +} diff --git a/client/cache/ssdb/ssdb_test.go b/client/cache/ssdb/ssdb_test.go new file mode 100644 index 0000000000..f8cec529b6 --- /dev/null +++ b/client/cache/ssdb/ssdb_test.go @@ -0,0 +1,211 @@ +package ssdb + +import ( + "context" + "errors" + "fmt" + "os" + "strconv" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "github.com/beego/beego/v2/client/cache" + "github.com/beego/beego/v2/core/berror" +) + +func TestSsdbcacheCache(t *testing.T) { + ssdbAddr := os.Getenv("SSDB_ADDR") + if ssdbAddr == "" { + ssdbAddr = "127.0.0.1:8888" + } + + ssdb, err := cache.NewCache("ssdb", fmt.Sprintf(`{"conn": "%s"}`, ssdbAddr)) + assert.Nil(t, err) + + // test put and exist + res, _ := ssdb.IsExist(context.Background(), "ssdb") + assert.False(t, res) + timeoutDuration := 3 * time.Second + // timeoutDuration := -10*time.Second if timeoutDuration is negtive,it means permanent + + assert.Nil(t, ssdb.Put(context.Background(), "ssdb", "ssdb", timeoutDuration)) + + res, _ = ssdb.IsExist(context.Background(), "ssdb") + assert.True(t, res) + + // Get test done + assert.Nil(t, ssdb.Put(context.Background(), "ssdb", "ssdb", timeoutDuration)) + + v, _ := ssdb.Get(context.Background(), "ssdb") + assert.Equal(t, "ssdb", v) + + // inc/dec test done + assert.Nil(t, ssdb.Put(context.Background(), "ssdb", "2", timeoutDuration)) + + assert.Nil(t, ssdb.Incr(context.Background(), "ssdb")) + + val, _ := ssdb.Get(context.Background(), "ssdb") + v, err = strconv.Atoi(val.(string)) + assert.Nil(t, err) + assert.Equal(t, 3, v) + + assert.Nil(t, ssdb.Decr(context.Background(), "ssdb")) + + // test del + assert.Nil(t, ssdb.Put(context.Background(), "ssdb", "3", timeoutDuration)) + + val, _ = ssdb.Get(context.Background(), "ssdb") + v, err = strconv.Atoi(val.(string)) + assert.Equal(t, 3, v) + assert.Nil(t, err) + + assert.Nil(t, ssdb.Delete(context.Background(), "ssdb")) + assert.Nil(t, ssdb.Put(context.Background(), "ssdb", "ssdb", -10*time.Second)) + // test string + + res, _ = ssdb.IsExist(context.Background(), "ssdb") + assert.True(t, res) + + v, _ = ssdb.Get(context.Background(), "ssdb") + assert.Equal(t, "ssdb", v.(string)) + + // test GetMulti done + assert.Nil(t, ssdb.Put(context.Background(), "ssdb1", "ssdb1", -10*time.Second)) + + res, _ = ssdb.IsExist(context.Background(), "ssdb1") + assert.True(t, res) + vv, _ := ssdb.GetMulti(context.Background(), []string{"ssdb", "ssdb1"}) + assert.Equal(t, 2, len(vv)) + + assert.Equal(t, "ssdb", vv[0]) + assert.Equal(t, "ssdb1", vv[1]) + + vv, err = ssdb.GetMulti(context.Background(), []string{"ssdb", "ssdb11"}) + + assert.Equal(t, 2, len(vv)) + + assert.Equal(t, "ssdb", vv[0]) + assert.Nil(t, vv[1]) + + assert.NotNil(t, err) + assert.True(t, strings.Contains(err.Error(), "key not exist")) + + // test clear all done + assert.Nil(t, ssdb.ClearAll(context.Background())) + e1, _ := ssdb.IsExist(context.Background(), "ssdb") + e2, _ := ssdb.IsExist(context.Background(), "ssdb1") + assert.False(t, e1) + assert.False(t, e2) +} + +func TestReadThroughCache_ssdb_Get(t *testing.T) { + bm, err := cache.NewCache("ssdb", fmt.Sprintf(`{"conn": "%s"}`, "127.0.0.1:8888")) + assert.Nil(t, err) + + testReadThroughCacheGet(t, bm) +} + +func testReadThroughCacheGet(t *testing.T, bm cache.Cache) { + testCases := []struct { + name string + key string + value string + cache cache.Cache + wantErr error + }{ + { + name: "Get load err", + key: "key0", + cache: func() cache.Cache { + kvs := map[string]any{"key0": "value0"} + db := &MockOrm{kvs: kvs} + loadfunc := func(ctx context.Context, key string) (any, error) { + val, er := db.Load(key) + if er != nil { + return nil, er + } + return val, nil + } + c, err := cache.NewReadThroughCache(bm, 3*time.Second, loadfunc) + assert.Nil(t, err) + return c + }(), + wantErr: func() error { + err := errors.New("the key not exist") + return berror.Wrap( + err, cache.LoadFuncFailed, "cache unable to load data") + }(), + }, + { + name: "Get cache exist", + key: "key1", + value: "value1", + cache: func() cache.Cache { + keysMap := map[string]int{"key1": 1} + kvs := map[string]any{"key1": "value1"} + db := &MockOrm{keysMap: keysMap, kvs: kvs} + loadfunc := func(ctx context.Context, key string) (any, error) { + val, er := db.Load(key) + if er != nil { + return nil, er + } + return val, nil + } + c, err := cache.NewReadThroughCache(bm, 3*time.Second, loadfunc) + assert.Nil(t, err) + err = c.Put(context.Background(), "key1", "value1", 3*time.Second) + assert.Nil(t, err) + return c + }(), + }, + { + name: "Get loadFunc exist", + key: "key2", + value: "value2", + cache: func() cache.Cache { + keysMap := map[string]int{"key2": 1} + kvs := map[string]any{"key2": "value2"} + db := &MockOrm{keysMap: keysMap, kvs: kvs} + loadfunc := func(ctx context.Context, key string) (any, error) { + val, er := db.Load(key) + if er != nil { + return nil, er + } + return val, nil + } + c, err := cache.NewReadThroughCache(bm, 3*time.Second, loadfunc) + assert.Nil(t, err) + return c + }(), + }, + } + _, err := cache.NewReadThroughCache(bm, 3*time.Second, nil) + assert.Equal(t, berror.Error(cache.InvalidLoadFunc, "loadFunc cannot be nil"), err) + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + c := tc.cache + val, err := c.Get(context.Background(), tc.key) + if err != nil { + assert.EqualError(t, tc.wantErr, err.Error()) + return + } + assert.Equal(t, tc.value, val) + }) + } +} + +type MockOrm struct { + keysMap map[string]int + kvs map[string]any +} + +func (m *MockOrm) Load(key string) (any, error) { + _, ok := m.keysMap[key] + if !ok { + return nil, errors.New("the key not exist") + } + return m.kvs[key], nil +} diff --git a/client/cache/write_delete.go b/client/cache/write_delete.go new file mode 100644 index 0000000000..a7ed3a1e64 --- /dev/null +++ b/client/cache/write_delete.go @@ -0,0 +1,96 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cache + +import ( + "context" + "errors" + "fmt" + "time" + + "github.com/beego/beego/v2/core/berror" +) + +type WriteDeleteCache struct { + Cache + storeFunc func(ctx context.Context, key string, val any) error +} + +// NewWriteDeleteCache creates a write delete cache pattern decorator. +// The fn is the function that persistent the key and val. +func NewWriteDeleteCache(cache Cache, fn func(ctx context.Context, key string, val any) error) (*WriteDeleteCache, error) { + if fn == nil || cache == nil { + return nil, berror.Error(InvalidInitParameters, "cache or storeFunc can not be nil") + } + + w := &WriteDeleteCache{ + Cache: cache, + storeFunc: fn, + } + return w, nil +} + +func (w *WriteDeleteCache) Set(ctx context.Context, key string, val any) error { + err := w.storeFunc(ctx, key, val) + if err != nil && !errors.Is(err, context.DeadlineExceeded) { + return berror.Wrap(err, PersistCacheFailed, fmt.Sprintf("key: %s, val: %v", key, val)) + } + return w.Cache.Delete(ctx, key) +} + +// WriteDoubleDeleteCache creates write double delete cache pattern decorator. +// The fn is the function that persistent the key and val. +// it will delete the key from cache when you call Set function, and wait for interval, it will delete the key from cache one more time. +// This pattern help to reduce the possibility of data inconsistencies, but it's still possible to be inconsistent among database and cache. +type WriteDoubleDeleteCache struct { + Cache + interval time.Duration + timeout time.Duration + storeFunc func(ctx context.Context, key string, val any) error +} + +type WriteDoubleDeleteCacheOption func(c *WriteDoubleDeleteCache) + +func NewWriteDoubleDeleteCache(cache Cache, interval, timeout time.Duration, + fn func(ctx context.Context, key string, val any) error) (*WriteDoubleDeleteCache, error) { + if fn == nil || cache == nil { + return nil, berror.Error(InvalidInitParameters, "cache or storeFunc can not be nil") + } + + return &WriteDoubleDeleteCache{ + Cache: cache, + interval: interval, + timeout: timeout, + storeFunc: fn, + }, nil +} + +func (c *WriteDoubleDeleteCache) Set( + ctx context.Context, key string, val any) error { + err := c.storeFunc(ctx, key, val) + if err != nil && !errors.Is(err, context.DeadlineExceeded) { + return berror.Wrap(err, PersistCacheFailed, fmt.Sprintf("key: %s, val: %v", key, val)) + } + time.AfterFunc(c.interval, func() { + rCtx, cancel := context.WithTimeout(context.Background(), c.timeout) + _ = c.Cache.Delete(rCtx, key) + cancel() + }) + err = c.Cache.Delete(ctx, key) + if err != nil { + return berror.Wrap(err, DeleteFailed, fmt.Sprintf("write double delete pattern failed to delete the key: %s", key)) + } + return nil +} diff --git a/client/cache/write_delete_test.go b/client/cache/write_delete_test.go new file mode 100644 index 0000000000..956f721a9e --- /dev/null +++ b/client/cache/write_delete_test.go @@ -0,0 +1,393 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// nolint +package cache + +import ( + "context" + "errors" + "fmt" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/stretchr/testify/assert" + + "github.com/beego/beego/v2/core/berror" +) + +func TestWriteDoubleDeleteCache_Set(t *testing.T) { + mockDbStore := make(map[string]any) + + cancels := make([]func(), 0) + defer func() { + for _, cancel := range cancels { + cancel() + } + }() + timeout := time.Second * 3 + testCases := []struct { + name string + cache Cache + storeFunc func(ctx context.Context, key string, val any) error + ctx context.Context + interval time.Duration + sleepSecond time.Duration + key string + value any + wantErr error + }{ + { + name: "store key/value in db fail", + interval: time.Second, + cache: NewMemoryCache(), + storeFunc: func(ctx context.Context, key string, val any) error { + return errors.New("failed") + }, + ctx: context.TODO(), + wantErr: berror.Wrap(errors.New("failed"), PersistCacheFailed, + fmt.Sprintf("key: %s, val: %v", "", nil)), + }, + { + name: "store key/value success", + interval: time.Second * 2, + sleepSecond: time.Second * 3, + cache: func() Cache { + cache := NewMemoryCache() + err := cache.Put(context.Background(), "hello", "world", time.Second*2) + require.NoError(t, err) + return cache + }(), + storeFunc: func(ctx context.Context, key string, val any) error { + mockDbStore[key] = val + return nil + }, + ctx: context.TODO(), + key: "hello", + value: "world", + }, + { + name: "store key/value timeout", + interval: time.Second * 2, + sleepSecond: time.Second * 3, + cache: func() Cache { + cache := NewMemoryCache() + err := cache.Put(context.Background(), "hello", "hello", time.Second*2) + require.NoError(t, err) + return cache + }(), + storeFunc: func(ctx context.Context, key string, val any) error { + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(3 * time.Second): + mockDbStore[key] = val + return nil + } + }, + ctx: func() context.Context { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + cancels = append(cancels, cancel) + return ctx + + }(), + key: "hello", + value: "hello", + }, + } + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + cache := tt.cache + c, err := NewWriteDoubleDeleteCache(cache, tt.interval, timeout, tt.storeFunc) + if err != nil { + assert.EqualError(t, tt.wantErr, err.Error()) + return + } + + err = c.Set(tt.ctx, tt.key, tt.value) + if err != nil { + assert.EqualError(t, tt.wantErr, err.Error()) + return + } + + _, err = c.Get(tt.ctx, tt.key) + assert.Equal(t, ErrKeyNotExist, err) + + err = cache.Put(tt.ctx, tt.key, tt.value, tt.interval) + require.NoError(t, err) + + val, err := c.Get(tt.ctx, tt.key) + require.NoError(t, err) + assert.Equal(t, tt.value, val) + + time.Sleep(tt.sleepSecond) + + _, err = c.Get(tt.ctx, tt.key) + assert.Equal(t, ErrKeyNotExist, err) + }) + } +} + +func TestNewWriteDoubleDeleteCache(t *testing.T) { + underlyingCache := NewMemoryCache() + storeFunc := func(ctx context.Context, key string, val any) error { return nil } + + type args struct { + cache Cache + interval time.Duration + fn func(ctx context.Context, key string, val any) error + } + timeout := time.Second * 3 + tests := []struct { + name string + args args + wantRes *WriteDoubleDeleteCache + wantErr error + }{ + { + name: "nil cache parameters", + args: args{ + cache: nil, + fn: storeFunc, + }, + wantErr: berror.Error(InvalidInitParameters, "cache or storeFunc can not be nil"), + }, + { + name: "nil storeFunc parameters", + args: args{ + cache: underlyingCache, + fn: nil, + }, + wantErr: berror.Error(InvalidInitParameters, "cache or storeFunc can not be nil"), + }, + { + name: "init write-though cache success", + args: args{ + cache: underlyingCache, + fn: storeFunc, + interval: time.Second, + }, + wantRes: &WriteDoubleDeleteCache{ + Cache: underlyingCache, + storeFunc: storeFunc, + interval: time.Second, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := NewWriteDoubleDeleteCache(tt.args.cache, tt.args.interval, timeout, tt.args.fn) + assert.Equal(t, tt.wantErr, err) + if err != nil { + return + } + }) + } +} + +func ExampleWriteDoubleDeleteCache() { + c := NewMemoryCache() + wtc, err := NewWriteDoubleDeleteCache(c, 1*time.Second, 3*time.Second, func(ctx context.Context, key string, val any) error { + fmt.Printf("write data to somewhere key %s, val %v \n", key, val) + return nil + }) + if err != nil { + panic(err) + } + err = wtc.Set(context.Background(), + "/biz/user/id=1", "I am user 1") + if err != nil { + panic(err) + } + // Output: + // write data to somewhere key /biz/user/id=1, val I am user 1 +} + +func TestWriteDeleteCache_Set(t *testing.T) { + mockDbStore := make(map[string]any) + + cancels := make([]func(), 0) + defer func() { + for _, cancel := range cancels { + cancel() + } + }() + + testCases := []struct { + name string + cache Cache + storeFunc func(ctx context.Context, key string, val any) error + ctx context.Context + key string + value any + wantErr error + before func(Cache) + after func() + }{ + { + name: "store key/value in db fail", + cache: NewMemoryCache(), + storeFunc: func(ctx context.Context, key string, val any) error { + return errors.New("failed") + }, + ctx: context.TODO(), + wantErr: berror.Wrap(errors.New("failed"), PersistCacheFailed, + fmt.Sprintf("key: %s, val: %v", "", nil)), + before: func(cache Cache) {}, + after: func() {}, + }, + { + name: "store key/value success", + cache: NewMemoryCache(), + storeFunc: func(ctx context.Context, key string, val any) error { + mockDbStore[key] = val + return nil + }, + ctx: context.TODO(), + key: "hello", + value: "world", + before: func(cache Cache) { + _ = cache.Put(context.Background(), "hello", "testVal", 10*time.Second) + }, + after: func() { + delete(mockDbStore, "hello") + }, + }, + { + name: "store key/value timeout", + cache: NewMemoryCache(), + storeFunc: func(ctx context.Context, key string, val any) error { + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(3 * time.Second): + mockDbStore[key] = val + return nil + } + + }, + ctx: func() context.Context { + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + cancels = append(cancels, cancel) + return ctx + + }(), + key: "hello", + value: nil, + before: func(cache Cache) { + _ = cache.Put(context.Background(), "hello", "testVal", 10*time.Second) + }, + after: func() {}, + }, + } + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + w, err := NewWriteDeleteCache(tt.cache, tt.storeFunc) + if err != nil { + assert.EqualError(t, tt.wantErr, err.Error()) + return + } + + tt.before(tt.cache) + defer func() { + tt.after() + }() + + err = w.Set(tt.ctx, tt.key, tt.value) + if err != nil { + assert.EqualError(t, tt.wantErr, err.Error()) + return + } + + _, err = w.Get(tt.ctx, tt.key) + assert.Equal(t, ErrKeyNotExist, err) + + vv := mockDbStore[tt.key] + assert.Equal(t, tt.value, vv) + }) + } +} + +func TestNewWriteDeleteCache(t *testing.T) { + underlyingCache := NewMemoryCache() + storeFunc := func(ctx context.Context, key string, val any) error { return nil } + + type args struct { + cache Cache + fn func(ctx context.Context, key string, val any) error + } + tests := []struct { + name string + args args + wantRes *WriteDeleteCache + wantErr error + }{ + { + name: "nil cache parameters", + args: args{ + cache: nil, + fn: storeFunc, + }, + wantErr: berror.Error(InvalidInitParameters, "cache or storeFunc can not be nil"), + }, + { + name: "nil storeFunc parameters", + args: args{ + cache: underlyingCache, + fn: nil, + }, + wantErr: berror.Error(InvalidInitParameters, "cache or storeFunc can not be nil"), + }, + { + name: "init write-though cache success", + args: args{ + cache: underlyingCache, + fn: storeFunc, + }, + wantRes: &WriteDeleteCache{ + Cache: underlyingCache, + storeFunc: storeFunc, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := NewWriteDeleteCache(tt.args.cache, tt.args.fn) + assert.Equal(t, tt.wantErr, err) + if err != nil { + return + } + }) + } +} + +func ExampleNewWriteDeleteCache() { + c := NewMemoryCache() + wtc, err := NewWriteDeleteCache(c, func(ctx context.Context, key string, val any) error { + fmt.Printf("write data to somewhere key %s, val %v \n", key, val) + return nil + }) + if err != nil { + panic(err) + } + err = wtc.Set(context.Background(), + "/biz/user/id=1", "I am user 1") + if err != nil { + panic(err) + } + // Output: + // write data to somewhere key /biz/user/id=1, val I am user 1 +} diff --git a/client/cache/write_through.go b/client/cache/write_through.go new file mode 100644 index 0000000000..6bf957b822 --- /dev/null +++ b/client/cache/write_through.go @@ -0,0 +1,50 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cache + +import ( + "context" + "fmt" + "time" + + "github.com/beego/beego/v2/core/berror" +) + +type WriteThroughCache struct { + Cache + storeFunc func(ctx context.Context, key string, val any) error +} + +// NewWriteThroughCache creates a write through cache pattern decorator. +// The fn is the function that persistent the key and val. +func NewWriteThroughCache(cache Cache, fn func(ctx context.Context, key string, val any) error) (*WriteThroughCache, error) { + if fn == nil || cache == nil { + return nil, berror.Error(InvalidInitParameters, "cache or storeFunc can not be nil") + } + + w := &WriteThroughCache{ + Cache: cache, + storeFunc: fn, + } + return w, nil +} + +func (w *WriteThroughCache) Set(ctx context.Context, key string, val any, expiration time.Duration) error { + err := w.storeFunc(ctx, key, val) + if err != nil { + return berror.Wrap(err, PersistCacheFailed, fmt.Sprintf("key: %s, val: %v", key, val)) + } + return w.Cache.Put(ctx, key, val, expiration) +} diff --git a/client/cache/write_through_test.go b/client/cache/write_through_test.go new file mode 100644 index 0000000000..d76af643fb --- /dev/null +++ b/client/cache/write_through_test.go @@ -0,0 +1,155 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// nolint +package cache + +import ( + "context" + "errors" + "fmt" + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "github.com/beego/beego/v2/core/berror" +) + +func TestWriteThoughCache_Set(t *testing.T) { + mockDbStore := make(map[string]any) + + testCases := []struct { + name string + cache Cache + storeFunc func(ctx context.Context, key string, val any) error + key string + value any + wantErr error + }{ + { + name: "store key/value in db fail", + cache: NewMemoryCache(), + storeFunc: func(ctx context.Context, key string, val any) error { + return errors.New("failed") + }, + wantErr: berror.Wrap(errors.New("failed"), PersistCacheFailed, + fmt.Sprintf("key: %s, val: %v", "", nil)), + }, + { + name: "store key/value success", + cache: NewMemoryCache(), + storeFunc: func(ctx context.Context, key string, val any) error { + mockDbStore[key] = val + return nil + }, + key: "hello", + value: "world", + }, + } + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + w, err := NewWriteThroughCache(tt.cache, tt.storeFunc) + if err != nil { + assert.EqualError(t, tt.wantErr, err.Error()) + return + } + + err = w.Set(context.Background(), tt.key, tt.value, 60*time.Second) + if err != nil { + assert.EqualError(t, tt.wantErr, err.Error()) + return + } + + val, err := w.Get(context.Background(), tt.key) + assert.Nil(t, err) + assert.Equal(t, tt.value, val) + + vv, ok := mockDbStore[tt.key] + assert.True(t, ok) + assert.Equal(t, tt.value, vv) + }) + } +} + +func TestNewWriteThoughCache(t *testing.T) { + underlyingCache := NewMemoryCache() + storeFunc := func(ctx context.Context, key string, val any) error { return nil } + + type args struct { + cache Cache + fn func(ctx context.Context, key string, val any) error + } + tests := []struct { + name string + args args + wantRes *WriteThroughCache + wantErr error + }{ + { + name: "nil cache parameters", + args: args{ + cache: nil, + fn: storeFunc, + }, + wantErr: berror.Error(InvalidInitParameters, "cache or storeFunc can not be nil"), + }, + { + name: "nil storeFunc parameters", + args: args{ + cache: underlyingCache, + fn: nil, + }, + wantErr: berror.Error(InvalidInitParameters, "cache or storeFunc can not be nil"), + }, + { + name: "init write-though cache success", + args: args{ + cache: underlyingCache, + fn: storeFunc, + }, + wantRes: &WriteThroughCache{ + Cache: underlyingCache, + storeFunc: storeFunc, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := NewWriteThroughCache(tt.args.cache, tt.args.fn) + assert.Equal(t, tt.wantErr, err) + if err != nil { + return + } + }) + } +} + +func ExampleNewWriteThroughCache() { + c := NewMemoryCache() + wtc, err := NewWriteThroughCache(c, func(ctx context.Context, key string, val any) error { + fmt.Printf("write data to somewhere key %s, val %v \n", key, val) + return nil + }) + if err != nil { + panic(err) + } + err = wtc.Set(context.Background(), + "/biz/user/id=1", "I am user 1", time.Minute) + if err != nil { + panic(err) + } + // Output: + // write data to somewhere key /biz/user/id=1, val I am user 1 +} diff --git a/httplib/README.md b/client/httplib/README.md similarity index 65% rename from httplib/README.md rename to client/httplib/README.md index 97df8e6b96..b1145e7a78 100644 --- a/httplib/README.md +++ b/client/httplib/README.md @@ -1,23 +1,26 @@ # httplib + httplib is an libs help you to curl remote url. # How to use? ## GET + you can use Get to crawl data. - import "github.com/astaxie/beego/httplib" - - str, err := httplib.Get("http://beego.me/").String() + import "github.com/beego/beego/v2/client/httplib" + + str, err := httplib.Get("http://beego.vip/").String() if err != nil { // error } fmt.Println(str) - + ## POST + POST data to remote url - req := httplib.Post("http://beego.me/") + req := httplib.Post("http://beego.vip/") req.Param("username","astaxie") req.Param("password","123456") str, err := req.String() @@ -35,40 +38,39 @@ The default timeout is `60` seconds, function prototype: Example: // GET - httplib.Get("http://beego.me/").SetTimeout(100 * time.Second, 30 * time.Second) - - // POST - httplib.Post("http://beego.me/").SetTimeout(100 * time.Second, 30 * time.Second) + httplib.Get("http://beego.vip/").SetTimeout(100 * time.Second, 30 * time.Second) + // POST + httplib.Post("http://beego.vip/").SetTimeout(100 * time.Second, 30 * time.Second) ## Debug If you want to debug the request info, set the debug on - httplib.Get("http://beego.me/").Debug(true) - + httplib.Get("http://beego.vip/").Debug(true) + ## Set HTTP Basic Auth - str, err := Get("http://beego.me/").SetBasicAuth("user", "passwd").String() + str, err := Get("http://beego.vip/").SetBasicAuth("user", "passwd").String() if err != nil { // error } fmt.Println(str) - + ## Set HTTPS If request url is https, You can set the client support TSL: httplib.SetTLSClientConfig(&tls.Config{InsecureSkipVerify: true}) - -More info about the `tls.Config` please visit http://golang.org/pkg/crypto/tls/#Config + +More info about the `tls.Config` please visit http://golang.org/pkg/crypto/tls/#Config ## Set HTTP Version some servers need to specify the protocol version of HTTP - httplib.Get("http://beego.me/").SetProtocolVersion("HTTP/1.1") - + httplib.Get("http://beego.vip/").SetProtocolVersion("HTTP/1.1") + ## Set Cookie some http request need setcookie. So set it like this: @@ -76,13 +78,13 @@ some http request need setcookie. So set it like this: cookie := &http.Cookie{} cookie.Name = "username" cookie.Value = "astaxie" - httplib.Get("http://beego.me/").SetCookie(cookie) + httplib.Get("http://beego.vip/").SetCookie(cookie) ## Upload file httplib support mutil file upload, use `req.PostFile()` - req := httplib.Post("http://beego.me/") + req := httplib.Post("http://beego.vip/") req.Param("username","astaxie") req.PostFile("uploadfile1", "httplib.pdf") str, err := req.String() @@ -91,7 +93,6 @@ httplib support mutil file upload, use `req.PostFile()` } fmt.Println(str) - See godoc for further documentation and examples. -* [godoc.org/github.com/astaxie/beego/httplib](https://godoc.org/github.com/astaxie/beego/httplib) +* [godoc.org/github.com/beego/beego/v2/client/httplib](https://godoc.org/github.com/beego/beego/v2/client/httplib) diff --git a/client/httplib/client_option.go b/client/httplib/client_option.go new file mode 100644 index 0000000000..2b8b672c66 --- /dev/null +++ b/client/httplib/client_option.go @@ -0,0 +1,155 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package httplib + +import ( + "crypto/tls" + "net/http" + "net/url" + "time" +) + +type ( + ClientOption func(client *Client) + BeegoHTTPRequestOption func(request *BeegoHTTPRequest) +) + +// WithEnableCookie will enable cookie in all subsequent request +func WithEnableCookie(enable bool) ClientOption { + return func(client *Client) { + client.Setting.EnableCookie = enable + } +} + +// WithUserAgent will adds UA in all subsequent request +func WithUserAgent(userAgent string) ClientOption { + return func(client *Client) { + client.Setting.UserAgent = userAgent + } +} + +// WithTLSClientConfig will adds tls config in all subsequent request +func WithTLSClientConfig(config *tls.Config) ClientOption { + return func(client *Client) { + client.Setting.TLSClientConfig = config + } +} + +// WithTransport will set transport field in all subsequent request +func WithTransport(transport http.RoundTripper) ClientOption { + return func(client *Client) { + client.Setting.Transport = transport + } +} + +// WithProxy will set http proxy field in all subsequent request +func WithProxy(proxy func(*http.Request) (*url.URL, error)) ClientOption { + return func(client *Client) { + client.Setting.Proxy = proxy + } +} + +// WithCheckRedirect will specifies the policy for handling redirects in all subsequent request +func WithCheckRedirect(redirect func(req *http.Request, via []*http.Request) error) ClientOption { + return func(client *Client) { + client.Setting.CheckRedirect = redirect + } +} + +// WithHTTPSetting can replace beegoHTTPSeting +func WithHTTPSetting(setting BeegoHTTPSettings) ClientOption { + return func(client *Client) { + client.Setting = setting + } +} + +// WithEnableGzip will enable gzip in all subsequent request +func WithEnableGzip(enable bool) ClientOption { + return func(client *Client) { + client.Setting.Gzip = enable + } +} + +// BeegoHttpRequestOption + +// WithTimeout sets connect time out and read-write time out for BeegoRequest. +func WithTimeout(connectTimeout, readWriteTimeout time.Duration) BeegoHTTPRequestOption { + return func(request *BeegoHTTPRequest) { + request.SetTimeout(connectTimeout, readWriteTimeout) + } +} + +// WithHeader adds header item string in request. +func WithHeader(key, value string) BeegoHTTPRequestOption { + return func(request *BeegoHTTPRequest) { + request.Header(key, value) + } +} + +// WithCookie adds a cookie to the request. +func WithCookie(cookie *http.Cookie) BeegoHTTPRequestOption { + return func(request *BeegoHTTPRequest) { + request.Header("Cookie", cookie.String()) + } +} + +// WithTokenFactory adds a custom function to set Authorization +func WithTokenFactory(tokenFactory func() string) BeegoHTTPRequestOption { + return func(request *BeegoHTTPRequest) { + t := tokenFactory() + + request.Header("Authorization", t) + } +} + +// WithBasicAuth adds a custom function to set basic auth +func WithBasicAuth(basicAuth func() (string, string)) BeegoHTTPRequestOption { + return func(request *BeegoHTTPRequest) { + username, password := basicAuth() + request.SetBasicAuth(username, password) + } +} + +// WithFilters will use the filter as the invocation filters +func WithFilters(fcs ...FilterChain) BeegoHTTPRequestOption { + return func(request *BeegoHTTPRequest) { + request.SetFilters(fcs...) + } +} + +// WithContentType adds ContentType in header +func WithContentType(contentType string) BeegoHTTPRequestOption { + return func(request *BeegoHTTPRequest) { + request.Header(contentTypeKey, contentType) + } +} + +// WithParam adds query param in to request. +func WithParam(key, value string) BeegoHTTPRequestOption { + return func(request *BeegoHTTPRequest) { + request.Param(key, value) + } +} + +// WithRetry set retry times and delay for the request +// default is 0 (never retry) +// -1 retry indefinitely (forever) +// Other numbers specify the exact retry amount +func WithRetry(times int, delay time.Duration) BeegoHTTPRequestOption { + return func(request *BeegoHTTPRequest) { + request.Retries(times) + request.RetryDelay(delay) + } +} diff --git a/client/httplib/error_code.go b/client/httplib/error_code.go new file mode 100644 index 0000000000..43a681868e --- /dev/null +++ b/client/httplib/error_code.go @@ -0,0 +1,138 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package httplib + +import ( + "github.com/beego/beego/v2/core/berror" +) + +var InvalidUrl = berror.DefineCode(4001001, moduleName, "InvalidUrl", ` +You pass an invalid url to httplib module. Please check your url, be careful about special character. +`) + +var InvalidUrlProtocolVersion = berror.DefineCode(4001002, moduleName, "InvalidUrlProtocolVersion", ` +You pass an invalid protocol version. In practice, we use HTTP/1.0, HTTP/1.1, HTTP/1.2 +But something like HTTP/3.2 is valid for client, and the major version is 3, minor version is 2. +but you must confirm that server support those abnormal protocol version. +`) + +var UnsupportedBodyType = berror.DefineCode(4001003, moduleName, "UnsupportedBodyType", ` +You use an invalid data as request body. +For now, we only support type string and byte[]. +`) + +var InvalidXMLBody = berror.DefineCode(4001004, moduleName, "InvalidXMLBody", ` +You pass invalid data which could not be converted to XML documents. In general, if you pass structure, it works well. +Sometimes you got XML document and you want to make it as request body. So you call XMLBody. +If you do this, you got this code. Instead, you should call Header to set Content-type and call Body to set body data. +`) + +var InvalidYAMLBody = berror.DefineCode(4001005, moduleName, "InvalidYAMLBody", ` +You pass invalid data which could not be converted to YAML documents. In general, if you pass structure, it works well. +Sometimes you got YAML document and you want to make it as request body. So you call YAMLBody. +If you do this, you got this code. Instead, you should call Header to set Content-type and call Body to set body data. +`) + +var InvalidJSONBody = berror.DefineCode(4001006, moduleName, "InvalidJSONBody", ` +You pass invalid data which could not be converted to JSON documents. In general, if you pass structure, it works well. +Sometimes you got JSON document and you want to make it as request body. So you call JSONBody. +If you do this, you got this code. Instead, you should call Header to set Content-type and call Body to set body data. +`) + +var InvalidURLOrMethod = berror.DefineCode(4001007, moduleName, "InvalidURLOrMethod", ` +You pass invalid url or method to httplib module. Please check the url and method, be careful about special characters. +`) + +// start with 5 -------------------------------------------------------------------------- + +var CreateFormFileFailed = berror.DefineCode(5001001, moduleName, "CreateFormFileFailed", ` +In normal case than handling files with BeegoRequest, you should not see this error code. +Unexpected EOF, invalid characters, bad file descriptor may cause this error. +`) + +var ReadFileFailed = berror.DefineCode(5001002, moduleName, "ReadFileFailed", ` +There are several cases that cause this error: +1. file not found. Please check the file name; +2. file not found, but file name is correct. If you use relative file path, it's very possible for you to see this code. +make sure that this file is in correct directory which Beego looks for; +3. Beego don't have the privilege to read this file, please change file mode; +`) + +var CopyFileFailed = berror.DefineCode(5001003, moduleName, "CopyFileFailed", ` +When we try to read file content and then copy it to another writer, and failed. +1. Unexpected EOF; +2. Bad file descriptor; +3. Write conflict; + +Please check your file content, and confirm that file is not processed by other process (or by user manually). +`) + +var CloseFileFailed = berror.DefineCode(5001004, moduleName, "CloseFileFailed", ` +After handling files, Beego try to close file but failed. Usually it was caused by bad file descriptor. +`) + +var SendRequestFailed = berror.DefineCode(5001005, moduleName, "SendRequestRetryExhausted", ` +Beego send HTTP request, but it failed. +If you config retry times, it means that Beego had retried and failed. +When you got this error, there are vary kind of reason: +1. Network unstable and timeout. In this case, sometimes server has received the request. +2. Server error. Make sure that server works well. +3. The request is invalid, which means that you pass some invalid parameter. +`) + +var ReadGzipBodyFailed = berror.DefineCode(5001006, moduleName, "BuildGzipReaderFailed", ` +Beego parse gzip-encode body failed. Usually Beego got invalid response. +Please confirm that server returns gzip data. +`) + +var CreateFileIfNotExistFailed = berror.DefineCode(5001007, moduleName, "CreateFileIfNotExist", ` +Beego want to create file if not exist and failed. +In most cases, it means that Beego doesn't have the privilege to create this file. +Please change file mode to ensure that Beego is able to create files on specific directory. +Or you can run Beego with higher authority. +In some cases, you pass invalid filename. Make sure that the file name is valid on your system. +`) + +var UnmarshalJSONResponseToObjectFailed = berror.DefineCode(5001008, moduleName, + "UnmarshalResponseToObjectFailed", ` +Beego trying to unmarshal response's body to structure but failed. +Make sure that: +1. You pass valid structure pointer to the function; +2. The body is valid json document +`) + +var UnmarshalXMLResponseToObjectFailed = berror.DefineCode(5001009, moduleName, + "UnmarshalResponseToObjectFailed", ` +Beego trying to unmarshal response's body to structure but failed. +Make sure that: +1. You pass valid structure pointer to the function; +2. The body is valid XML document +`) + +var UnmarshalYAMLResponseToObjectFailed = berror.DefineCode(5001010, moduleName, + "UnmarshalResponseToObjectFailed", ` +Beego trying to unmarshal response's body to structure but failed. +Make sure that: +1. You pass valid structure pointer to the function; +2. The body is valid YAML document +`) + +var UnmarshalResponseToObjectFailed = berror.DefineCode(5001011, moduleName, + "UnmarshalResponseToObjectFailed", ` +Beego trying to unmarshal response's body to structure but failed. +There are several cases that cause this error: +1. You pass valid structure pointer to the function; +2. The body is valid json, Yaml or XML document +`) diff --git a/client/httplib/filter.go b/client/httplib/filter.go new file mode 100644 index 0000000000..5daed64c87 --- /dev/null +++ b/client/httplib/filter.go @@ -0,0 +1,24 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package httplib + +import ( + "context" + "net/http" +) + +type FilterChain func(next Filter) Filter + +type Filter func(ctx context.Context, req *BeegoHTTPRequest) (*http.Response, error) diff --git a/client/httplib/filter/log/filter.go b/client/httplib/filter/log/filter.go new file mode 100644 index 0000000000..9d2e09d3bd --- /dev/null +++ b/client/httplib/filter/log/filter.go @@ -0,0 +1,130 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package log + +import ( + "context" + "io" + "net/http" + "net/http/httputil" + + "github.com/beego/beego/v2/client/httplib" + "github.com/beego/beego/v2/core/logs" +) + +// FilterChainBuilder can build a log filter +type FilterChainBuilder struct { + printableContentTypes []string // only print the body of included mime types of request and response + log func(f interface{}, v ...interface{}) // custom log function +} + +// BuilderOption option constructor +type BuilderOption func(*FilterChainBuilder) + +type logInfo struct { + req []byte + resp []byte + err error +} + +var defaultprintableContentTypes = []string{ + "text/plain", "text/xml", "text/html", "text/csv", + "text/calendar", "text/javascript", "text/javascript", + "text/css", +} + +// NewFilterChainBuilder initialize a filterChainBuilder, pass options to customize +func NewFilterChainBuilder(opts ...BuilderOption) *FilterChainBuilder { + res := &FilterChainBuilder{ + printableContentTypes: defaultprintableContentTypes, + log: logs.Debug, + } + for _, o := range opts { + o(res) + } + + return res +} + +// WithLog return option constructor modify log function +func WithLog(f func(f interface{}, v ...interface{})) BuilderOption { + return func(h *FilterChainBuilder) { + h.log = f + } +} + +// WithprintableContentTypes return option constructor modify printableContentTypes +func WithprintableContentTypes(types []string) BuilderOption { + return func(h *FilterChainBuilder) { + h.printableContentTypes = types + } +} + +// FilterChain can print the request after FilterChain processing and response before processsing +func (builder *FilterChainBuilder) FilterChain(next httplib.Filter) httplib.Filter { + return func(ctx context.Context, req *httplib.BeegoHTTPRequest) (*http.Response, error) { + info := &logInfo{} + defer info.print(builder.log) + resp, err := next(ctx, req) + info.err = err + contentType := req.GetRequest().Header.Get("Content-Type") + shouldPrintBody := builder.shouldPrintBody(contentType, req.GetRequest().Body) + dump, err := httputil.DumpRequest(req.GetRequest(), shouldPrintBody) + info.req = dump + if err != nil { + logs.Error(err) + } + if resp != nil { + contentType = resp.Header.Get("Content-Type") + shouldPrintBody = builder.shouldPrintBody(contentType, resp.Body) + dump, err = httputil.DumpResponse(resp, shouldPrintBody) + info.resp = dump + if err != nil { + logs.Error(err) + } + } + return resp, err + } +} + +func (builder *FilterChainBuilder) shouldPrintBody(contentType string, body io.ReadCloser) bool { + if contains(builder.printableContentTypes, contentType) { + return true + } + if body != nil { + logs.Warn("printableContentTypes do not contain %s, if you want to print request and response body please add it.", contentType) + } + return false +} + +func contains(s []string, e string) bool { + for _, a := range s { + if a == e { + return true + } + } + return false +} + +func (info *logInfo) print(log func(f interface{}, v ...interface{})) { + log("Request: ====================") + log("%q", info.req) + log("Response: ===================") + log("%q", info.resp) + if info.err != nil { + log("Error: ======================") + log("%q", info.err) + } +} diff --git a/client/httplib/filter/log/filter_test.go b/client/httplib/filter/log/filter_test.go new file mode 100644 index 0000000000..b17dd395ec --- /dev/null +++ b/client/httplib/filter/log/filter_test.go @@ -0,0 +1,63 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package log + +import ( + "context" + "net/http" + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "github.com/beego/beego/v2/client/httplib" +) + +func TestFilterChain(t *testing.T) { + next := func(ctx context.Context, req *httplib.BeegoHTTPRequest) (*http.Response, error) { + time.Sleep(100 * time.Millisecond) + return &http.Response{ + StatusCode: 404, + }, nil + } + builder := NewFilterChainBuilder() + filter := builder.FilterChain(next) + req := httplib.Get("https://github.com/notifications?query=repo%3Aastaxie%2Fbeego") + resp, err := filter(context.Background(), req) + assert.NotNil(t, resp) + assert.Nil(t, err) +} + +func TestContains(t *testing.T) { + jsonType := "application/json" + cases := []struct { + Name string + Types []string + ContentType string + Expected bool + }{ + {"case1", []string{jsonType}, jsonType, true}, + {"case2", []string{"text/plain"}, jsonType, false}, + } + + for _, c := range cases { + t.Run(c.Name, func(t *testing.T) { + if ans := contains(c.Types, c.ContentType); ans != c.Expected { + t.Fatalf("Types: %v, ContentType: %v, expected %v, but %v got", + c.Types, c.ContentType, c.Expected, ans) + } + }) + } +} diff --git a/client/httplib/filter/opentelemetry/filter.go b/client/httplib/filter/opentelemetry/filter.go new file mode 100644 index 0000000000..cee1458dd1 --- /dev/null +++ b/client/httplib/filter/opentelemetry/filter.go @@ -0,0 +1,86 @@ +// Copyright 2021 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package opentelemetry + +import ( + "context" + "net/http" + + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/propagation" + "go.opentelemetry.io/otel/trace" + + "github.com/beego/beego/v2/client/httplib" +) + +type CustomSpanFunc func(span trace.Span, ctx context.Context, req *httplib.BeegoHTTPRequest, resp *http.Response, err error) + +type OtelFilterChainBuilder struct { + // TagURL true will tag span with url + tagURL bool + // CustomSpanFunc users are able to custom their span + customSpanFunc CustomSpanFunc +} + +func NewOpenTelemetryFilter(tagURL bool, spanFunc CustomSpanFunc) *OtelFilterChainBuilder { + return &OtelFilterChainBuilder{ + tagURL: tagURL, + customSpanFunc: spanFunc, + } +} + +func (builder *OtelFilterChainBuilder) FilterChain(next httplib.Filter) httplib.Filter { + return func(ctx context.Context, req *httplib.BeegoHTTPRequest) (*http.Response, error) { + method := req.GetRequest().Method + + operationName := method + "#" + req.GetRequest().URL.Path + spanCtx, span := otel.Tracer("beego").Start(ctx, operationName) + defer span.End() + + otel.GetTextMapPropagator().Inject(spanCtx, propagation.HeaderCarrier(req.GetRequest().Header)) + + resp, err := next(spanCtx, req) + + if resp != nil { + span.SetAttributes(attribute.Int("http.status_code", resp.StatusCode)) + } + span.SetAttributes(attribute.String("http.method", method)) + span.SetAttributes(attribute.String("peer.hostname", req.GetRequest().URL.Host)) + + span.SetAttributes(attribute.String("http.scheme", req.GetRequest().URL.Scheme)) + span.SetAttributes(attribute.String("span.kind", "client")) + span.SetAttributes(attribute.String("component", "beego")) + + if builder.tagURL { + span.SetAttributes(attribute.String("http.url", req.GetRequest().URL.String())) + } + + if err != nil { + span.SetAttributes(attribute.Bool("error", true)) + span.RecordError(err) + } else if resp != nil && !(resp.StatusCode < 300 && resp.StatusCode >= 200) { + span.SetAttributes(attribute.Bool("error", true)) + } + + span.SetAttributes(attribute.String("peer.address", req.GetRequest().RemoteAddr)) + span.SetAttributes(attribute.String("http.proto", req.GetRequest().Proto)) + + if builder.customSpanFunc != nil { + builder.customSpanFunc(span, ctx, req, resp, err) + } + return resp, err + } +} diff --git a/client/httplib/filter/opentelemetry/filter_test.go b/client/httplib/filter/opentelemetry/filter_test.go new file mode 100644 index 0000000000..4949e37561 --- /dev/null +++ b/client/httplib/filter/opentelemetry/filter_test.go @@ -0,0 +1,46 @@ +// Copyright 2021 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package opentelemetry + +import ( + "context" + "errors" + "net/http" + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "github.com/beego/beego/v2/client/httplib" +) + +func TestFilterChainBuilderFilterChain(t *testing.T) { + next := func(ctx context.Context, req *httplib.BeegoHTTPRequest) (*http.Response, error) { + time.Sleep(100 * time.Millisecond) + return &http.Response{ + StatusCode: 404, + Body: http.NoBody, + }, errors.New("hello") + } + builder := NewOpenTelemetryFilter(true, nil) + filter := builder.FilterChain(next) + req := httplib.Get("https://github.com/notifications?query=repo%3Aastaxie%2Fbeego") + resp, err := filter(context.Background(), req) + + defer resp.Body.Close() + + assert.NotNil(t, resp) + assert.NotNil(t, err) +} diff --git a/client/httplib/filter/opentracing/filter.go b/client/httplib/filter/opentracing/filter.go new file mode 100644 index 0000000000..ec212e5b1b --- /dev/null +++ b/client/httplib/filter/opentracing/filter.go @@ -0,0 +1,79 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package opentracing + +import ( + "context" + "net/http" + + opentracingKit "github.com/go-kit/kit/tracing/opentracing" + logKit "github.com/go-kit/log" + "github.com/opentracing/opentracing-go" + "github.com/opentracing/opentracing-go/log" + + "github.com/beego/beego/v2/client/httplib" +) + +type FilterChainBuilder struct { + // TagURL true will tag span with url + TagURL bool + // CustomSpanFunc users are able to custom their span + CustomSpanFunc func(span opentracing.Span, ctx context.Context, + req *httplib.BeegoHTTPRequest, resp *http.Response, err error) +} + +func (builder *FilterChainBuilder) FilterChain(next httplib.Filter) httplib.Filter { + return func(ctx context.Context, req *httplib.BeegoHTTPRequest) (*http.Response, error) { + method := req.GetRequest().Method + + operationName := method + "#" + req.GetRequest().URL.String() + span, spanCtx := opentracing.StartSpanFromContext(ctx, operationName) + defer span.Finish() + + inject := opentracingKit.ContextToHTTP(opentracing.GlobalTracer(), logKit.NewNopLogger()) + inject(spanCtx, req.GetRequest()) + resp, err := next(spanCtx, req) + + if resp != nil { + span.SetTag("http.status_code", resp.StatusCode) + } + span.SetTag("http.method", method) + span.SetTag("peer.hostname", req.GetRequest().URL.Host) + + span.SetTag("http.scheme", req.GetRequest().URL.Scheme) + span.SetTag("span.kind", "client") + span.SetTag("component", "beego") + + if builder.TagURL { + span.SetTag("http.url", req.GetRequest().URL.String()) + } + span.LogFields(log.String("http.url", req.GetRequest().URL.String())) + + if err != nil { + span.SetTag("error", true) + span.LogFields(log.String("message", err.Error())) + } else if resp != nil && !(resp.StatusCode < 300 && resp.StatusCode >= 200) { + span.SetTag("error", true) + } + + span.SetTag("peer.address", req.GetRequest().RemoteAddr) + span.SetTag("http.proto", req.GetRequest().Proto) + + if builder.CustomSpanFunc != nil { + builder.CustomSpanFunc(span, ctx, req, resp, err) + } + return resp, err + } +} diff --git a/client/httplib/filter/opentracing/filter_test.go b/client/httplib/filter/opentracing/filter_test.go new file mode 100644 index 0000000000..a9b9cbb033 --- /dev/null +++ b/client/httplib/filter/opentracing/filter_test.go @@ -0,0 +1,44 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package opentracing + +import ( + "context" + "errors" + "net/http" + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "github.com/beego/beego/v2/client/httplib" +) + +func TestFilterChainBuilderFilterChain(t *testing.T) { + next := func(ctx context.Context, req *httplib.BeegoHTTPRequest) (*http.Response, error) { + time.Sleep(100 * time.Millisecond) + return &http.Response{ + StatusCode: 404, + }, errors.New("hello") + } + builder := &FilterChainBuilder{ + TagURL: true, + } + filter := builder.FilterChain(next) + req := httplib.Get("https://github.com/notifications?query=repo%3Aastaxie%2Fbeego") + resp, err := filter(context.Background(), req) + assert.NotNil(t, resp) + assert.NotNil(t, err) +} diff --git a/client/httplib/filter/prometheus/filter.go b/client/httplib/filter/prometheus/filter.go new file mode 100644 index 0000000000..e93b229807 --- /dev/null +++ b/client/httplib/filter/prometheus/filter.go @@ -0,0 +1,85 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package prometheus + +import ( + "context" + "net/http" + "strconv" + "sync" + "time" + + "github.com/prometheus/client_golang/prometheus" + + "github.com/beego/beego/v2/client/httplib" +) + +type FilterChainBuilder struct { + AppName string + ServerName string + RunMode string +} + +var ( + summaryVec prometheus.ObserverVec + initSummaryVec sync.Once +) + +func (builder *FilterChainBuilder) FilterChain(next httplib.Filter) httplib.Filter { + initSummaryVec.Do(func() { + summaryVec = prometheus.NewSummaryVec(prometheus.SummaryOpts{ + Name: "beego", + Subsystem: "remote_http_request", + ConstLabels: map[string]string{ + "server": builder.ServerName, + "env": builder.RunMode, + "appname": builder.AppName, + }, + Help: "The statics info for remote http requests", + }, []string{"proto", "scheme", "method", "host", "path", "status", "isError"}) + + prometheus.MustRegister(summaryVec) + }) + + return func(ctx context.Context, req *httplib.BeegoHTTPRequest) (*http.Response, error) { + startTime := time.Now() + resp, err := next(ctx, req) + endTime := time.Now() + go builder.report(startTime, endTime, ctx, req, resp, err) + return resp, err + } +} + +func (builder *FilterChainBuilder) report(startTime time.Time, endTime time.Time, + ctx context.Context, req *httplib.BeegoHTTPRequest, resp *http.Response, err error) { + + proto := req.GetRequest().Proto + + scheme := req.GetRequest().URL.Scheme + method := req.GetRequest().Method + + host := req.GetRequest().URL.Host + path := req.GetRequest().URL.Path + + status := -1 + if resp != nil { + status = resp.StatusCode + } + + dur := int(endTime.Sub(startTime) / time.Millisecond) + + summaryVec.WithLabelValues(proto, scheme, method, host, path, + strconv.Itoa(status), strconv.FormatBool(err != nil)).Observe(float64(dur)) +} diff --git a/client/httplib/filter/prometheus/filter_test.go b/client/httplib/filter/prometheus/filter_test.go new file mode 100644 index 0000000000..4a7b29f2cf --- /dev/null +++ b/client/httplib/filter/prometheus/filter_test.go @@ -0,0 +1,41 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package prometheus + +import ( + "context" + "net/http" + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "github.com/beego/beego/v2/client/httplib" +) + +func TestFilterChainBuilderFilterChain(t *testing.T) { + next := func(ctx context.Context, req *httplib.BeegoHTTPRequest) (*http.Response, error) { + time.Sleep(100 * time.Millisecond) + return &http.Response{ + StatusCode: 404, + }, nil + } + builder := &FilterChainBuilder{} + filter := builder.FilterChain(next) + req := httplib.Get("https://github.com/notifications?query=repo%3Aastaxie%2Fbeego") + resp, err := filter(context.Background(), req) + assert.NotNil(t, resp) + assert.Nil(t, err) +} diff --git a/client/httplib/http_response.go b/client/httplib/http_response.go new file mode 100644 index 0000000000..2f2ca2f82e --- /dev/null +++ b/client/httplib/http_response.go @@ -0,0 +1,39 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package httplib + +import ( + "bytes" + "encoding/json" + "io" + "net/http" +) + +// NewHttpResponseWithJsonBody will try to convert the data to json format +// usually you only use this when you want to mock http Response +func NewHttpResponseWithJsonBody(data interface{}) *http.Response { + var body []byte + if str, ok := data.(string); ok { + body = []byte(str) + } else if bts, ok := data.([]byte); ok { + body = bts + } else { + body, _ = json.Marshal(data) + } + return &http.Response{ + ContentLength: int64(len(body)), + Body: io.NopCloser(bytes.NewReader(body)), + } +} diff --git a/client/httplib/http_response_test.go b/client/httplib/http_response_test.go new file mode 100644 index 0000000000..a62dd42cc3 --- /dev/null +++ b/client/httplib/http_response_test.go @@ -0,0 +1,35 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package httplib + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestNewHttpResponseWithJsonBody(t *testing.T) { + // string + resp := NewHttpResponseWithJsonBody("{}") + assert.Equal(t, int64(2), resp.ContentLength) + + resp = NewHttpResponseWithJsonBody([]byte("{}")) + assert.Equal(t, int64(2), resp.ContentLength) + + resp = NewHttpResponseWithJsonBody(&user{ + Name: "Tom", + }) + assert.True(t, resp.ContentLength > 0) +} diff --git a/client/httplib/httpclient.go b/client/httplib/httpclient.go new file mode 100644 index 0000000000..f3ac4afe70 --- /dev/null +++ b/client/httplib/httpclient.go @@ -0,0 +1,173 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package httplib + +import ( + "bytes" + "io" + "net/http" +) + +// Client provides an HTTP client supporting chain call +type Client struct { + Name string + Endpoint string + CommonOpts []BeegoHTTPRequestOption + + Setting BeegoHTTPSettings +} + +// HTTPResponseCarrier If value implement HTTPResponseCarrier. http.Response will pass to SetHTTPResponse +type HTTPResponseCarrier interface { + SetHTTPResponse(resp *http.Response) +} + +// HTTPBodyCarrier If value implement HTTPBodyCarrier. http.Response.Body will pass to SetReader +type HTTPBodyCarrier interface { + SetReader(r io.ReadCloser) +} + +// HTTPBytesCarrier If value implement HTTPBytesCarrier. +// All the byte in http.Response.Body will pass to SetBytes +type HTTPBytesCarrier interface { + SetBytes(bytes []byte) +} + +// HTTPStatusCarrier If value implement HTTPStatusCarrier. http.Response.StatusCode will pass to SetStatusCode +type HTTPStatusCarrier interface { + SetStatusCode(status int) +} + +// HTTPHeadersCarrier If value implement HttpHeaderCarrier. http.Response.Header will pass to SetHeader +type HTTPHeadersCarrier interface { + SetHeader(header map[string][]string) +} + +// NewClient return a new http client +func NewClient(name string, endpoint string, opts ...ClientOption) (*Client, error) { + res := &Client{ + Name: name, + Endpoint: endpoint, + } + setting := GetDefaultSetting() + res.Setting = setting + for _, o := range opts { + o(res) + } + return res, nil +} + +func (c *Client) customReq(req *BeegoHTTPRequest, opts []BeegoHTTPRequestOption) { + req.Setting(c.Setting) + opts = append(c.CommonOpts, opts...) + for _, o := range opts { + o(req) + } +} + +// handleResponse try to parse body to meaningful value +func (c *Client) handleResponse(value interface{}, req *BeegoHTTPRequest) error { + // make sure req.resp is not nil + _, err := req.Bytes() + if err != nil { + return err + } + + err = c.handleCarrier(value, req) + if err != nil { + return err + } + + return req.ToValue(value) +} + +// handleCarrier set http data to value +func (c *Client) handleCarrier(value interface{}, req *BeegoHTTPRequest) error { + if value == nil { + return nil + } + + if carrier, ok := value.(HTTPResponseCarrier); ok { + b, err := req.Bytes() + if err != nil { + return err + } + req.resp.Body = io.NopCloser(bytes.NewReader(b)) + carrier.SetHTTPResponse(req.resp) + } + if carrier, ok := value.(HTTPBodyCarrier); ok { + b, err := req.Bytes() + if err != nil { + return err + } + reader := io.NopCloser(bytes.NewReader(b)) + carrier.SetReader(reader) + } + if carrier, ok := value.(HTTPBytesCarrier); ok { + b, err := req.Bytes() + if err != nil { + return err + } + carrier.SetBytes(b) + } + if carrier, ok := value.(HTTPStatusCarrier); ok { + carrier.SetStatusCode(req.resp.StatusCode) + } + if carrier, ok := value.(HTTPHeadersCarrier); ok { + carrier.SetHeader(req.resp.Header) + } + return nil +} + +// Get Send a GET request and try to give its result value +func (c *Client) Get(value interface{}, path string, opts ...BeegoHTTPRequestOption) error { + req := Get(c.Endpoint + path) + c.customReq(req, opts) + return c.handleResponse(value, req) +} + +// Post Send a POST request and try to give its result value +func (c *Client) Post(value interface{}, path string, body interface{}, opts ...BeegoHTTPRequestOption) error { + req := Post(c.Endpoint + path) + c.customReq(req, opts) + if body != nil { + req = req.Body(body) + } + return c.handleResponse(value, req) +} + +// Put Send a Put request and try to give its result value +func (c *Client) Put(value interface{}, path string, body interface{}, opts ...BeegoHTTPRequestOption) error { + req := Put(c.Endpoint + path) + c.customReq(req, opts) + if body != nil { + req = req.Body(body) + } + return c.handleResponse(value, req) +} + +// Delete Send a Delete request and try to give its result value +func (c *Client) Delete(value interface{}, path string, opts ...BeegoHTTPRequestOption) error { + req := Delete(c.Endpoint + path) + c.customReq(req, opts) + return c.handleResponse(value, req) +} + +// Head Send a Head request and try to give its result value +func (c *Client) Head(value interface{}, path string, opts ...BeegoHTTPRequestOption) error { + req := Head(c.Endpoint + path) + c.customReq(req, opts) + return c.handleResponse(value, req) +} diff --git a/client/httplib/httpclient_test.go b/client/httplib/httpclient_test.go new file mode 100644 index 0000000000..b7dc8769b6 --- /dev/null +++ b/client/httplib/httpclient_test.go @@ -0,0 +1,294 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package httplib + +import ( + "encoding/json" + "encoding/xml" + "io" + "net" + "net/http" + "testing" + + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + "gopkg.in/yaml.v3" + + "github.com/stretchr/testify/assert" +) + +func TestNewClient(t *testing.T) { + client, err := NewClient("test1", "http://beego.vip", WithEnableCookie(true)) + assert.NoError(t, err) + assert.NotNil(t, client) + assert.Equal(t, true, client.Setting.EnableCookie) +} + +type slideShowResponse struct { + Resp *http.Response `json:"resp,omitempty"` + bytes []byte `json:"bytes,omitempty"` + StatusCode int `json:"status_code,omitempty"` + Body io.ReadCloser `json:"body,omitempty"` + Header map[string][]string `json:"header,omitempty"` + + Slideshow slideshow `json:"slideshow,omitempty" yaml:"slideshow" xml:"slideshow"` +} + +func (r *slideShowResponse) SetHTTPResponse(resp *http.Response) { + r.Resp = resp +} + +func (r *slideShowResponse) SetBytes(bytes []byte) { + r.bytes = bytes +} + +func (r *slideShowResponse) SetReader(reader io.ReadCloser) { + r.Body = reader +} + +func (r *slideShowResponse) SetStatusCode(status int) { + r.StatusCode = status +} + +func (r *slideShowResponse) SetHeader(header map[string][]string) { + r.Header = header +} + +func (r *slideShowResponse) String() string { + return string(r.bytes) +} + +type slideshow struct { + //XMLName xml.Name `xml:"slideshow"` + + Title string `json:"title" yaml:"title" xml:"title,attr"` + Author string `json:"author" yaml:"author" xml:"author,attr"` + Date string `json:"date" yaml:"date" xml:"date,attr"` + Slides []slide `json:"slides" yaml:"slides" xml:"slide"` +} + +type slide struct { + XMLName xml.Name `xml:"slide"` + + Title string `json:"title" yaml:"title" xml:"title"` +} + +type ClientTestSuite struct { + suite.Suite + l net.Listener +} + +func (c *ClientTestSuite) SetupSuite() { + listener, err := net.Listen("tcp", ":8080") + require.NoError(c.T(), err) + c.l = listener + + handler := http.NewServeMux() + handler.HandleFunc("/json", func(writer http.ResponseWriter, request *http.Request) { + data, _ := json.Marshal(slideshow{}) + _, _ = writer.Write(data) + }) + + ssr := slideShowResponse{ + Slideshow: slideshow{ + Title: "Sample Slide Show", + Slides: []slide{ + { + Title: "Content", + }, + { + Title: "Overview", + }, + }, + }, + } + + handler.HandleFunc("/req2resp", func(writer http.ResponseWriter, request *http.Request) { + data, _ := io.ReadAll(request.Body) + _, _ = writer.Write(data) + }) + + handler.HandleFunc("/get", func(writer http.ResponseWriter, request *http.Request) { + data, _ := json.Marshal(ssr) + _, _ = writer.Write(data) + }) + + handler.HandleFunc("/get/xml", func(writer http.ResponseWriter, request *http.Request) { + data, err := xml.Marshal(ssr.Slideshow) + require.NoError(c.T(), err) + _, _ = writer.Write(data) + }) + + handler.HandleFunc("/get/yaml", func(writer http.ResponseWriter, request *http.Request) { + data, _ := yaml.Marshal(ssr) + _, _ = writer.Write(data) + }) + + go func() { + _ = http.Serve(listener, handler) + }() +} + +func (c *ClientTestSuite) TearDownSuite() { + _ = c.l.Close() +} + +func TestClient(t *testing.T) { + suite.Run(t, &ClientTestSuite{}) +} + +func (c *ClientTestSuite) TestClientHandleCarrier() { + t := c.T() + v := "beego" + client, err := NewClient("test", "http://localhost:8080/", + WithUserAgent(v)) + require.NoError(t, err) + resp := &slideShowResponse{} + err = client.Get(resp, "/json") + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + + assert.NotNil(t, resp.Resp) + assert.NotNil(t, resp.Body) + assert.Equal(t, "48", resp.Header["Content-Length"][0]) + assert.Equal(t, 200, resp.StatusCode) + + b, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + assert.Equal(t, 48, len(b)) + assert.Equal(t, resp.String(), string(b)) +} + +func (c *ClientTestSuite) TestClientGet() { + t := c.T() + client, err := NewClient("test", "http://localhost:8080/") + if err != nil { + t.Fatal(err) + } + + // json + var s slideShowResponse + err = client.Get(&s, "/get") + require.NoError(t, err) + assert.Equal(t, "Sample Slide Show", s.Slideshow.Title) + assert.Equal(t, 2, len(s.Slideshow.Slides)) + assert.Equal(t, "Overview", s.Slideshow.Slides[1].Title) + + // xml + var ss slideshow + err = client.Get(&ss, "/get/xml") + require.NoError(t, err) + assert.Equal(t, "Sample Slide Show", ss.Title) + assert.Equal(t, 2, len(ss.Slides)) + assert.Equal(t, "Overview", ss.Slides[1].Title) + + // yaml + s = slideShowResponse{} + err = client.Get(&s, "/get/yaml") + if err != nil { + t.Fatal(err) + } + assert.Equal(t, "Sample Slide Show", s.Slideshow.Title) + assert.Equal(t, 2, len(s.Slideshow.Slides)) + assert.Equal(t, "Overview", s.Slideshow.Slides[1].Title) +} + +func (c *ClientTestSuite) TestClientPost() { + t := c.T() + client, err := NewClient("test", "http://localhost:8080") + require.NoError(t, err) + + input := slideShowResponse{ + Slideshow: slideshow{ + Title: "Sample Slide Show", + Slides: []slide{ + { + Title: "Content", + }, + { + Title: "Overview", + }, + }, + }, + } + + jsonStr, err := json.Marshal(input) + require.NoError(t, err) + resp := slideShowResponse{} + err = client.Post(&resp, "/req2resp", jsonStr) + require.NoError(t, err) + assert.Equal(t, input.Slideshow, resp.Slideshow) + assert.Equal(t, http.MethodPost, resp.Resp.Request.Method) +} + +func (c *ClientTestSuite) TestClientPut() { + t := c.T() + client, err := NewClient("test", "http://localhost:8080") + require.NoError(t, err) + + input := slideShowResponse{ + Slideshow: slideshow{ + Title: "Sample Slide Show", + Slides: []slide{ + { + Title: "Content", + }, + { + Title: "Overview", + }, + }, + }, + } + + jsonStr, err := json.Marshal(input) + require.NoError(t, err) + resp := slideShowResponse{} + err = client.Put(&resp, "/req2resp", jsonStr) + require.NoError(t, err) + assert.Equal(t, input.Slideshow, resp.Slideshow) + assert.Equal(t, http.MethodPut, resp.Resp.Request.Method) +} + +func (c *ClientTestSuite) TestClientDelete() { + t := c.T() + client, err := NewClient("test", "http://localhost:8080") + require.NoError(t, err) + + resp := &slideShowResponse{} + err = client.Delete(resp, "/req2resp") + require.NoError(t, err) + defer resp.Resp.Body.Close() + + assert.NotNil(t, resp) + assert.Equal(t, http.MethodDelete, resp.Resp.Request.Method) +} + +func (c *ClientTestSuite) TestClientHead() { + t := c.T() + client, err := NewClient("test", "http://localhost:8080") + require.NoError(t, err) + resp := &slideShowResponse{} + err = client.Head(resp, "/req2resp") + if err != nil { + t.Fatal(err) + } + defer resp.Resp.Body.Close() + assert.NotNil(t, resp) + assert.Equal(t, http.MethodHead, resp.Resp.Request.Method) +} diff --git a/httplib/httplib.go b/client/httplib/httplib.go similarity index 50% rename from httplib/httplib.go rename to client/httplib/httplib.go index 074cf66159..fca981cea5 100644 --- a/httplib/httplib.go +++ b/client/httplib/httplib.go @@ -15,9 +15,9 @@ // Package httplib is used as http.Client // Usage: // -// import "github.com/astaxie/beego/httplib" +// import "github.com/beego/beego/v2/client/httplib" // -// b := httplib.Post("http://beego.me/") +// b := httplib.Post("http://beego.vip/") // b.Param("username","astaxie") // b.Param("password","123456") // b.PostFile("uploadfile1", "httplib.pdf") @@ -27,79 +27,62 @@ // t.Fatal(err) // } // fmt.Println(str) -// -// more docs http://beego.me/docs/module/httplib.md package httplib import ( "bytes" "compress/gzip" + "context" "crypto/tls" "encoding/json" "encoding/xml" "io" - "io/ioutil" - "log" "mime/multipart" "net" "net/http" - "net/http/cookiejar" - "net/http/httputil" "net/url" "os" + "path/filepath" "strings" - "sync" "time" - "gopkg.in/yaml.v2" -) -var defaultSetting = BeegoHTTPSettings{ - UserAgent: "beegoServer", - ConnectTimeout: 60 * time.Second, - ReadWriteTimeout: 60 * time.Second, - Gzip: true, - DumpBody: true, -} + "gopkg.in/yaml.v3" -var defaultCookieJar http.CookieJar -var settingMutex sync.Mutex + "github.com/beego/beego/v2/core/berror" + "github.com/beego/beego/v2/core/logs" +) -// createDefaultCookie creates a global cookiejar to store cookies. -func createDefaultCookie() { - settingMutex.Lock() - defer settingMutex.Unlock() - defaultCookieJar, _ = cookiejar.New(nil) -} +const contentTypeKey = "Content-Type" -// SetDefaultSetting Overwrite default settings -func SetDefaultSetting(setting BeegoHTTPSettings) { - settingMutex.Lock() - defer settingMutex.Unlock() - defaultSetting = setting +// it will be the last filter and execute request.Do +var doRequestFilter = func(ctx context.Context, req *BeegoHTTPRequest) (*http.Response, error) { + return req.doRequest(ctx) } -// NewBeegoRequest return *BeegoHttpRequest with specific method +// NewBeegoRequest returns *BeegoHttpRequest with specific method +// TODO add error as return value +// I think if we don't return error +// users are hard to check whether we create Beego request successfully func NewBeegoRequest(rawurl, method string) *BeegoHTTPRequest { - var resp http.Response - u, err := url.Parse(rawurl) + return NewBeegoRequestWithCtx(context.Background(), rawurl, method) +} + +// NewBeegoRequestWithCtx returns a new BeegoHTTPRequest given a method, URL +func NewBeegoRequestWithCtx(ctx context.Context, rawurl, method string) *BeegoHTTPRequest { + req, err := http.NewRequestWithContext(ctx, method, rawurl, nil) if err != nil { - log.Println("Httplib:", err) - } - req := http.Request{ - URL: u, - Method: method, - Header: make(http.Header), - Proto: "HTTP/1.1", - ProtoMajor: 1, - ProtoMinor: 1, + logs.Error("%+v", berror.Wrapf(err, InvalidURLOrMethod, "invalid raw url or method: %s %s", rawurl, method)) } return &BeegoHTTPRequest{ url: rawurl, - req: &req, + req: req, params: map[string][]string{}, files: map[string]string{}, setting: defaultSetting, - resp: &resp, + resp: &http.Response{}, + copyBody: func() io.ReadCloser { + return nil + }, } } @@ -128,23 +111,7 @@ func Head(url string) *BeegoHTTPRequest { return NewBeegoRequest(url, "HEAD") } -// BeegoHTTPSettings is the http.Client setting -type BeegoHTTPSettings struct { - ShowDebug bool - UserAgent string - ConnectTimeout time.Duration - ReadWriteTimeout time.Duration - TLSClientConfig *tls.Config - Proxy func(*http.Request) (*url.URL, error) - Transport http.RoundTripper - CheckRedirect func(req *http.Request, via []*http.Request) error - EnableCookie bool - Gzip bool - DumpBody bool - Retries int // if set to -1 means will retry forever -} - -// BeegoHTTPRequest provides more useful methods for requesting one url than http.Request. +// BeegoHTTPRequest provides more useful methods than http.Request for requesting an url. type BeegoHTTPRequest struct { url string req *http.Request @@ -152,16 +119,18 @@ type BeegoHTTPRequest struct { files map[string]string setting BeegoHTTPSettings resp *http.Response - body []byte - dump []byte + // body the response body, not the request body + body []byte + // copyBody support retry strategy to avoid copy request body + copyBody func() io.ReadCloser } -// GetRequest return the request object +// GetRequest returns the request object func (b *BeegoHTTPRequest) GetRequest() *http.Request { return b.req } -// Setting Change request settings +// Setting changes request settings func (b *BeegoHTTPRequest) Setting(setting BeegoHTTPSettings) *BeegoHTTPRequest { b.setting = setting return b @@ -185,32 +154,21 @@ func (b *BeegoHTTPRequest) SetUserAgent(useragent string) *BeegoHTTPRequest { return b } -// Debug sets show debug or not when executing request. -func (b *BeegoHTTPRequest) Debug(isdebug bool) *BeegoHTTPRequest { - b.setting.ShowDebug = isdebug - return b -} - // Retries sets Retries times. -// default is 0 means no retried. -// -1 means retried forever. -// others means retried times. +// default is 0 (never retry) +// -1 retry indefinitely (forever) +// Other numbers specify the exact retry amount func (b *BeegoHTTPRequest) Retries(times int) *BeegoHTTPRequest { b.setting.Retries = times return b } -// DumpBody setting whether need to Dump the Body. -func (b *BeegoHTTPRequest) DumpBody(isdump bool) *BeegoHTTPRequest { - b.setting.DumpBody = isdump +// RetryDelay sets the time to sleep between reconnection attempts +func (b *BeegoHTTPRequest) RetryDelay(delay time.Duration) *BeegoHTTPRequest { + b.setting.RetryDelay = delay return b } -// DumpRequest return the DumpRequest -func (b *BeegoHTTPRequest) DumpRequest() []byte { - return b.dump -} - // SetTimeout sets connect time out and read-write time out for BeegoRequest. func (b *BeegoHTTPRequest) SetTimeout(connectTimeout, readWriteTimeout time.Duration) *BeegoHTTPRequest { b.setting.ConnectTimeout = connectTimeout @@ -218,13 +176,13 @@ func (b *BeegoHTTPRequest) SetTimeout(connectTimeout, readWriteTimeout time.Dura return b } -// SetTLSClientConfig sets tls connection configurations if visiting https url. +// SetTLSClientConfig sets TLS connection configuration if visiting HTTPS url. func (b *BeegoHTTPRequest) SetTLSClientConfig(config *tls.Config) *BeegoHTTPRequest { b.setting.TLSClientConfig = config return b } -// Header add header item string in request. +// Header adds header item string in request. func (b *BeegoHTTPRequest) Header(key, value string) *BeegoHTTPRequest { b.req.Header.Set(key, value) return b @@ -236,10 +194,10 @@ func (b *BeegoHTTPRequest) SetHost(host string) *BeegoHTTPRequest { return b } -// SetProtocolVersion Set the protocol version for incoming requests. -// Client requests always use HTTP/1.1. +// SetProtocolVersion sets the protocol version for incoming requests. +// Client requests always use HTTP/1.1 func (b *BeegoHTTPRequest) SetProtocolVersion(vers string) *BeegoHTTPRequest { - if len(vers) == 0 { + if vers == "" { vers = "HTTP/1.1" } @@ -248,30 +206,31 @@ func (b *BeegoHTTPRequest) SetProtocolVersion(vers string) *BeegoHTTPRequest { b.req.Proto = vers b.req.ProtoMajor = major b.req.ProtoMinor = minor + return b } - + logs.Error("%+v", berror.Errorf(InvalidUrlProtocolVersion, "invalid protocol: %s", vers)) return b } -// SetCookie add cookie into request. +// SetCookie adds a cookie to the request. func (b *BeegoHTTPRequest) SetCookie(cookie *http.Cookie) *BeegoHTTPRequest { b.req.Header.Add("Cookie", cookie.String()) return b } -// SetTransport set the setting transport +// SetTransport sets the transport field func (b *BeegoHTTPRequest) SetTransport(transport http.RoundTripper) *BeegoHTTPRequest { b.setting.Transport = transport return b } -// SetProxy set the http proxy +// SetProxy sets the HTTP proxy // example: // // func(req *http.Request) (*url.URL, error) { -// u, _ := url.ParseRequestURI("http://127.0.0.1:8118") -// return u, nil -// } +// u, _ := url.ParseRequestURI("http://127.0.0.1:8118") +// return u, nil +// } func (b *BeegoHTTPRequest) SetProxy(proxy func(*http.Request) (*url.URL, error)) *BeegoHTTPRequest { b.setting.Proxy = proxy return b @@ -286,6 +245,24 @@ func (b *BeegoHTTPRequest) SetCheckRedirect(redirect func(req *http.Request, via return b } +// SetFilters will use the filter as the invocation filters +func (b *BeegoHTTPRequest) SetFilters(fcs ...FilterChain) *BeegoHTTPRequest { + b.setting.FilterChains = fcs + return b +} + +// AddFilters adds filter +func (b *BeegoHTTPRequest) AddFilters(fcs ...FilterChain) *BeegoHTTPRequest { + b.setting.FilterChains = append(b.setting.FilterChains, fcs...) + return b +} + +// SetEscapeHTML is used to set the flag whether escape HTML special characters during processing +func (b *BeegoHTTPRequest) SetEscapeHTML(isEscape bool) *BeegoHTTPRequest { + b.setting.EscapeHTML = isEscape + return b +} + // Param adds query param in to request. // params build query string as ?key1=value1&key2=value2... func (b *BeegoHTTPRequest) Param(key, value string) *BeegoHTTPRequest { @@ -297,70 +274,90 @@ func (b *BeegoHTTPRequest) Param(key, value string) *BeegoHTTPRequest { return b } -// PostFile add a post file to the request +// PostFile adds a post file to the request func (b *BeegoHTTPRequest) PostFile(formname, filename string) *BeegoHTTPRequest { b.files[formname] = filename return b } // Body adds request raw body. -// it supports string and []byte. +// Supports string and []byte. +// TODO return error if data is invalid func (b *BeegoHTTPRequest) Body(data interface{}) *BeegoHTTPRequest { switch t := data.(type) { case string: - bf := bytes.NewBufferString(t) - b.req.Body = ioutil.NopCloser(bf) - b.req.ContentLength = int64(len(t)) + b.reqBody([]byte(t)) case []byte: - bf := bytes.NewBuffer(t) - b.req.Body = ioutil.NopCloser(bf) - b.req.ContentLength = int64(len(t)) + b.reqBody(t) + default: + logs.Error("%+v", berror.Errorf(UnsupportedBodyType, "unsupported body data type: %s", t)) + } + return b +} + +func (b *BeegoHTTPRequest) reqBody(data []byte) *BeegoHTTPRequest { + body := io.NopCloser(bytes.NewReader(data)) + b.req.Body = body + b.req.GetBody = func() (io.ReadCloser, error) { + return body, nil + } + b.req.ContentLength = int64(len(data)) + b.copyBody = func() io.ReadCloser { + return io.NopCloser(bytes.NewReader(data)) } return b } -// XMLBody adds request raw body encoding by XML. +// XMLBody adds the request raw body encoded in XML. func (b *BeegoHTTPRequest) XMLBody(obj interface{}) (*BeegoHTTPRequest, error) { if b.req.Body == nil && obj != nil { byts, err := xml.Marshal(obj) if err != nil { - return b, err + return b, berror.Wrap(err, InvalidXMLBody, "obj could not be converted to XML data") } - b.req.Body = ioutil.NopCloser(bytes.NewReader(byts)) - b.req.ContentLength = int64(len(byts)) - b.req.Header.Set("Content-Type", "application/xml") + b.reqBody(byts) + b.req.Header.Set(contentTypeKey, "application/xml") } return b, nil } -// YAMLBody adds request raw body encoding by YAML. +// YAMLBody adds the request raw body encoded in YAML. func (b *BeegoHTTPRequest) YAMLBody(obj interface{}) (*BeegoHTTPRequest, error) { if b.req.Body == nil && obj != nil { byts, err := yaml.Marshal(obj) if err != nil { - return b, err + return b, berror.Wrap(err, InvalidYAMLBody, "obj could not be converted to YAML data") } - b.req.Body = ioutil.NopCloser(bytes.NewReader(byts)) - b.req.ContentLength = int64(len(byts)) - b.req.Header.Set("Content-Type", "application/x+yaml") + b.reqBody(byts) + b.req.Header.Set(contentTypeKey, "application/x+yaml") } return b, nil } -// JSONBody adds request raw body encoding by JSON. +// JSONBody adds the request raw body encoded in JSON. func (b *BeegoHTTPRequest) JSONBody(obj interface{}) (*BeegoHTTPRequest, error) { if b.req.Body == nil && obj != nil { - byts, err := json.Marshal(obj) + byts, err := b.JSONMarshal(obj) if err != nil { - return b, err + return b, berror.Wrap(err, InvalidJSONBody, "obj could not be converted to JSON body") } - b.req.Body = ioutil.NopCloser(bytes.NewReader(byts)) - b.req.ContentLength = int64(len(byts)) - b.req.Header.Set("Content-Type", "application/json") + b.reqBody(byts) + b.req.Header.Set(contentTypeKey, "application/json") } return b, nil } +func (b *BeegoHTTPRequest) JSONMarshal(obj interface{}) ([]byte, error) { + bf := bytes.NewBuffer([]byte{}) + jsonEncoder := json.NewEncoder(bf) + jsonEncoder.SetEscapeHTML(b.setting.EscapeHTML) + err := jsonEncoder.Encode(obj) + if err != nil { + return nil, err + } + return bf.Bytes(), nil +} + func (b *BeegoHTTPRequest) buildURL(paramBody string) { // build GET url with query string if b.req.Method == "GET" && len(paramBody) > 0 { @@ -376,46 +373,60 @@ func (b *BeegoHTTPRequest) buildURL(paramBody string) { if (b.req.Method == "POST" || b.req.Method == "PUT" || b.req.Method == "PATCH" || b.req.Method == "DELETE") && b.req.Body == nil { // with files if len(b.files) > 0 { - pr, pw := io.Pipe() - bodyWriter := multipart.NewWriter(pw) - go func() { - for formname, filename := range b.files { - fileWriter, err := bodyWriter.CreateFormFile(formname, filename) - if err != nil { - log.Println("Httplib:", err) - } - fh, err := os.Open(filename) - if err != nil { - log.Println("Httplib:", err) - } - //iocopy - _, err = io.Copy(fileWriter, fh) - fh.Close() - if err != nil { - log.Println("Httplib:", err) - } - } - for k, v := range b.params { - for _, vv := range v { - bodyWriter.WriteField(k, vv) - } - } - bodyWriter.Close() - pw.Close() - }() - b.Header("Content-Type", bodyWriter.FormDataContentType()) - b.req.Body = ioutil.NopCloser(pr) + b.handleFiles() return } // with params if len(paramBody) > 0 { - b.Header("Content-Type", "application/x-www-form-urlencoded") + b.Header(contentTypeKey, "application/x-www-form-urlencoded") b.Body(paramBody) } } } +func (b *BeegoHTTPRequest) handleFiles() { + pr, pw := io.Pipe() + bodyWriter := multipart.NewWriter(pw) + go func() { + for formname, filename := range b.files { + b.handleFileToBody(bodyWriter, formname, filename) + } + for k, v := range b.params { + for _, vv := range v { + _ = bodyWriter.WriteField(k, vv) + } + } + _ = bodyWriter.Close() + _ = pw.Close() + }() + b.Header(contentTypeKey, bodyWriter.FormDataContentType()) + b.req.Body = io.NopCloser(pr) + b.Header("Transfer-Encoding", "chunked") +} + +func (*BeegoHTTPRequest) handleFileToBody(bodyWriter *multipart.Writer, formname string, filename string) { + fileWriter, err := bodyWriter.CreateFormFile(formname, filename) + const errFmt = "Httplib: %+v" + if err != nil { + logs.Error(errFmt, berror.Wrapf(err, CreateFormFileFailed, + "could not create form file, formname: %s, filename: %s", formname, filename)) + } + fh, err := os.Open(filename) + if err != nil { + logs.Error(errFmt, berror.Wrapf(err, ReadFileFailed, "could not open this file %s", filename)) + } + // iocopy + _, err = io.Copy(fileWriter, fh) + if err != nil { + logs.Error(errFmt, berror.Wrapf(err, CopyFileFailed, "could not copy this file %s", filename)) + } + err = fh.Close() + if err != nil { + logs.Error(errFmt, berror.Wrapf(err, CloseFileFailed, "could not close this file %s", filename)) + } +} + func (b *BeegoHTTPRequest) getResponse() (*http.Response, error) { if b.resp.StatusCode != 0 { return b.resp, nil @@ -428,63 +439,42 @@ func (b *BeegoHTTPRequest) getResponse() (*http.Response, error) { return resp, nil } -// DoRequest will do the client.Do +// DoRequest executes client.Do func (b *BeegoHTTPRequest) DoRequest() (resp *http.Response, err error) { - var paramBody string - if len(b.params) > 0 { - var buf bytes.Buffer - for k, v := range b.params { - for _, vv := range v { - buf.WriteString(url.QueryEscape(k)) - buf.WriteByte('=') - buf.WriteString(url.QueryEscape(vv)) - buf.WriteByte('&') - } + root := doRequestFilter + if len(b.setting.FilterChains) > 0 { + for i := len(b.setting.FilterChains) - 1; i >= 0; i-- { + root = b.setting.FilterChains[i](root) } - paramBody = buf.String() - paramBody = paramBody[0 : len(paramBody)-1] } + return root(b.req.Context(), b) +} + +// Deprecated: please use NewBeegoRequestWithContext +func (b *BeegoHTTPRequest) DoRequestWithCtx(ctx context.Context) (resp *http.Response, err error) { + root := doRequestFilter + if len(b.setting.FilterChains) > 0 { + for i := len(b.setting.FilterChains) - 1; i >= 0; i-- { + root = b.setting.FilterChains[i](root) + } + } + return root(ctx, b) +} + +func (b *BeegoHTTPRequest) doRequest(_ context.Context) (*http.Response, error) { + paramBody := b.buildParamBody() b.buildURL(paramBody) urlParsed, err := url.Parse(b.url) if err != nil { - return nil, err + return nil, berror.Wrapf(err, InvalidUrl, "parse url failed, the url is %s", b.url) } b.req.URL = urlParsed - trans := b.setting.Transport + trans := b.buildTrans() - if trans == nil { - // create default transport - trans = &http.Transport{ - TLSClientConfig: b.setting.TLSClientConfig, - Proxy: b.setting.Proxy, - Dial: TimeoutDialer(b.setting.ConnectTimeout, b.setting.ReadWriteTimeout), - MaxIdleConnsPerHost: 100, - } - } else { - // if b.transport is *http.Transport then set the settings. - if t, ok := trans.(*http.Transport); ok { - if t.TLSClientConfig == nil { - t.TLSClientConfig = b.setting.TLSClientConfig - } - if t.Proxy == nil { - t.Proxy = b.setting.Proxy - } - if t.Dial == nil { - t.Dial = TimeoutDialer(b.setting.ConnectTimeout, b.setting.ReadWriteTimeout) - } - } - } - - var jar http.CookieJar - if b.setting.EnableCookie { - if defaultCookieJar == nil { - createDefaultCookie() - } - jar = defaultCookieJar - } + jar := b.buildCookieJar() client := &http.Client{ Transport: trans, @@ -499,27 +489,82 @@ func (b *BeegoHTTPRequest) DoRequest() (resp *http.Response, err error) { client.CheckRedirect = b.setting.CheckRedirect } - if b.setting.ShowDebug { - dump, err := httputil.DumpRequest(b.req, b.setting.DumpBody) - if err != nil { - log.Println(err.Error()) - } - b.dump = dump - } + return b.sendRequest(client) +} + +func (b *BeegoHTTPRequest) sendRequest(client *http.Client) (resp *http.Response, err error) { // retries default value is 0, it will run once. // retries equal to -1, it will run forever until success - // retries is setted, it will retries fixed times. + // retries is set, it will retry fixed times. + // Sleeps for a 400ms between calls to reduce spam for i := 0; b.setting.Retries == -1 || i <= b.setting.Retries; i++ { resp, err = client.Do(b.req) if err == nil { - break + return + } + time.Sleep(b.setting.RetryDelay) + b.req.Body = b.copyBody() + } + return nil, berror.Wrap(err, SendRequestFailed, "sending request fail") +} + +func (b *BeegoHTTPRequest) buildCookieJar() http.CookieJar { + var jar http.CookieJar + if b.setting.EnableCookie { + if defaultCookieJar == nil { + createDefaultCookie() + } + jar = defaultCookieJar + } + return jar +} + +func (b *BeegoHTTPRequest) buildTrans() http.RoundTripper { + trans := b.setting.Transport + + if trans == nil { + // create default transport + trans = &http.Transport{ + TLSClientConfig: b.setting.TLSClientConfig, + Proxy: b.setting.Proxy, + DialContext: TimeoutDialerCtx(b.setting.ConnectTimeout, b.setting.ReadWriteTimeout), + MaxIdleConnsPerHost: 100, + } + } else if t, ok := trans.(*http.Transport); ok { + // if b.transport is *http.Transport then set the settings. + if t.TLSClientConfig == nil { + t.TLSClientConfig = b.setting.TLSClientConfig + } + if t.Proxy == nil { + t.Proxy = b.setting.Proxy + } + if t.DialContext == nil { + t.DialContext = TimeoutDialerCtx(b.setting.ConnectTimeout, b.setting.ReadWriteTimeout) + } + } + return trans +} + +func (b *BeegoHTTPRequest) buildParamBody() string { + var paramBody string + if len(b.params) > 0 { + var buf bytes.Buffer + for k, v := range b.params { + for _, vv := range v { + buf.WriteString(url.QueryEscape(k)) + buf.WriteByte('=') + buf.WriteString(url.QueryEscape(vv)) + buf.WriteByte('&') + } } + paramBody = buf.String() + paramBody = paramBody[0 : len(paramBody)-1] } - return resp, err + return paramBody } // String returns the body string in response. -// it calls Response inner. +// Calls Response inner. func (b *BeegoHTTPRequest) String() (string, error) { data, err := b.Bytes() if err != nil { @@ -530,7 +575,7 @@ func (b *BeegoHTTPRequest) String() (string, error) { } // Bytes returns the body []byte in response. -// it calls Response inner. +// Calls Response inner. func (b *BeegoHTTPRequest) Bytes() ([]byte, error) { if b.body != nil { return b.body, nil @@ -546,24 +591,18 @@ func (b *BeegoHTTPRequest) Bytes() ([]byte, error) { if b.setting.Gzip && resp.Header.Get("Content-Encoding") == "gzip" { reader, err := gzip.NewReader(resp.Body) if err != nil { - return nil, err + return nil, berror.Wrap(err, ReadGzipBodyFailed, "building gzip reader failed") } - b.body, err = ioutil.ReadAll(reader) - return b.body, err + b.body, err = io.ReadAll(reader) + return b.body, berror.Wrap(err, ReadGzipBodyFailed, "reading gzip data failed") } - b.body, err = ioutil.ReadAll(resp.Body) + b.body, err = io.ReadAll(resp.Body) return b.body, err } // ToFile saves the body data in response to one file. -// it calls Response inner. +// Calls Response inner. func (b *BeegoHTTPRequest) ToFile(filename string) error { - f, err := os.Create(filename) - if err != nil { - return err - } - defer f.Close() - resp, err := b.getResponse() if err != nil { return err @@ -572,48 +611,120 @@ func (b *BeegoHTTPRequest) ToFile(filename string) error { return nil } defer resp.Body.Close() + err = pathExistAndMkdir(filename) + if err != nil { + return err + } + f, err := os.Create(filename) + if err != nil { + return err + } + defer f.Close() _, err = io.Copy(f, resp.Body) return err } -// ToJSON returns the map that marshals from the body bytes as json in response . -// it calls Response inner. +// Check if the file directory exists. If it doesn't then it's created +func pathExistAndMkdir(filename string) (err error) { + filename = filepath.Dir(filename) + _, err = os.Stat(filename) + if err == nil { + return nil + } + if os.IsNotExist(err) { + err = os.MkdirAll(filename, os.ModePerm) + if err == nil { + return nil + } + } + return berror.Wrapf(err, CreateFileIfNotExistFailed, "try to create(if not exist) failed: %s", filename) +} + +// ToJSON returns the map that marshals from the body bytes as json in response. +// Calls Response inner. func (b *BeegoHTTPRequest) ToJSON(v interface{}) error { data, err := b.Bytes() if err != nil { return err } - return json.Unmarshal(data, v) + return berror.Wrap(json.Unmarshal(data, v), + UnmarshalJSONResponseToObjectFailed, "unmarshal json body to object failed.") } // ToXML returns the map that marshals from the body bytes as xml in response . -// it calls Response inner. +// Calls Response inner. func (b *BeegoHTTPRequest) ToXML(v interface{}) error { data, err := b.Bytes() if err != nil { return err } - return xml.Unmarshal(data, v) + return berror.Wrap(xml.Unmarshal(data, v), + UnmarshalXMLResponseToObjectFailed, "unmarshal xml body to object failed.") } // ToYAML returns the map that marshals from the body bytes as yaml in response . -// it calls Response inner. +// Calls Response inner. func (b *BeegoHTTPRequest) ToYAML(v interface{}) error { data, err := b.Bytes() if err != nil { return err } - return yaml.Unmarshal(data, v) + return berror.Wrap(yaml.Unmarshal(data, v), + UnmarshalYAMLResponseToObjectFailed, "unmarshal yaml body to object failed.") +} + +// ToValue attempts to resolve the response body to value using an existing method. +// Calls Response inner. +// If response header contain Content-Type, func will call ToJSON\ToXML\ToYAML. +// Else it will try to parse body as json\yaml\xml, If all attempts fail, an error will be returned +func (b *BeegoHTTPRequest) ToValue(value interface{}) error { + if value == nil { + return nil + } + + contentType := strings.Split(b.resp.Header.Get(contentTypeKey), ";")[0] + // try to parse it as content type + switch contentType { + case "application/json": + return b.ToJSON(value) + case "text/xml", "application/xml": + return b.ToXML(value) + case "text/yaml", "application/x-yaml", "application/x+yaml": + return b.ToYAML(value) + } + + // try to parse it anyway + if err := b.ToJSON(value); err == nil { + return nil + } + if err := b.ToYAML(value); err == nil { + return nil + } + if err := b.ToXML(value); err == nil { + return nil + } + + return berror.Error(UnmarshalResponseToObjectFailed, "unmarshal body to object failed.") } -// Response executes request client gets response mannually. +// Response executes request client gets response manually. func (b *BeegoHTTPRequest) Response() (*http.Response, error) { return b.getResponse() } // TimeoutDialer returns functions of connection dialer with timeout settings for http.Transport Dial field. +// Deprecated +// we will move this at the end of 2021 +// please use TimeoutDialerCtx func TimeoutDialer(cTimeout time.Duration, rwTimeout time.Duration) func(net, addr string) (c net.Conn, err error) { return func(netw, addr string) (net.Conn, error) { + return TimeoutDialerCtx(cTimeout, rwTimeout)(context.Background(), netw, addr) + } +} + +func TimeoutDialerCtx(cTimeout time.Duration, + rwTimeout time.Duration) func(ctx context.Context, net, addr string) (c net.Conn, err error) { + return func(ctx context.Context, netw, addr string) (net.Conn, error) { conn, err := net.DialTimeout(netw, addr, cTimeout) if err != nil { return nil, err diff --git a/client/httplib/httplib_test.go b/client/httplib/httplib_test.go new file mode 100644 index 0000000000..dd5710580d --- /dev/null +++ b/client/httplib/httplib_test.go @@ -0,0 +1,509 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package httplib + +import ( + "bytes" + "context" + json "encoding/json" + "errors" + "fmt" + "io" + "net" + "net/http" + "os" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/suite" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type HttplibTestSuite struct { + suite.Suite + l net.Listener +} + +func (h *HttplibTestSuite) SetupSuite() { + listener, err := net.Listen("tcp", ":8080") + require.NoError(h.T(), err) + h.l = listener + + handler := http.NewServeMux() + + handler.HandleFunc("/get", func(writer http.ResponseWriter, request *http.Request) { + agent := request.Header.Get("User-Agent") + _, _ = writer.Write([]byte("hello, " + agent)) + }) + + handler.HandleFunc("/put", func(writer http.ResponseWriter, request *http.Request) { + _, _ = writer.Write([]byte("hello, put")) + }) + + handler.HandleFunc("/post", func(writer http.ResponseWriter, request *http.Request) { + body, _ := io.ReadAll(request.Body) + _, _ = writer.Write(body) + }) + + handler.HandleFunc("/delete", func(writer http.ResponseWriter, request *http.Request) { + _, _ = writer.Write([]byte("hello, delete")) + }) + + handler.HandleFunc("/cookies/set", func(writer http.ResponseWriter, request *http.Request) { + k1 := request.URL.Query().Get("k1") + http.SetCookie(writer, &http.Cookie{ + Name: "k1", + Value: k1, + }) + _, _ = writer.Write([]byte("hello, set cookie")) + }) + + handler.HandleFunc("/cookies", func(writer http.ResponseWriter, request *http.Request) { + body := request.Cookies()[0].String() + _, _ = writer.Write([]byte(body)) + }) + + handler.HandleFunc("/basic-auth/user/passwd", func(writer http.ResponseWriter, request *http.Request) { + _, _, ok := request.BasicAuth() + if ok { + _, _ = writer.Write([]byte("authenticated")) + } else { + _, _ = writer.Write([]byte("no auth")) + } + }) + + handler.HandleFunc("/headers", func(writer http.ResponseWriter, request *http.Request) { + agent := request.Header.Get("User-Agent") + _, _ = writer.Write([]byte(agent)) + }) + + handler.HandleFunc("/ip", func(writer http.ResponseWriter, request *http.Request) { + data := map[string]string{"origin": "127.0.0.1"} + jsonBytes, _ := json.Marshal(data) + _, _ = writer.Write(jsonBytes) + }) + + handler.HandleFunc("/redirect", func(writer http.ResponseWriter, request *http.Request) { + http.Redirect(writer, request, "redirect_dst", http.StatusTemporaryRedirect) + }) + handler.HandleFunc("/redirect_dst", func(writer http.ResponseWriter, request *http.Request) { + _, _ = writer.Write([]byte("hello")) + }) + + handler.HandleFunc("/retry", func(writer http.ResponseWriter, request *http.Request) { + body, err := io.ReadAll(request.Body) + require.NoError(h.T(), err) + assert.Equal(h.T(), []byte("retry body"), body) + panic("mock error") + }) + + go func() { + _ = http.Serve(listener, handler) + }() +} + +func (h *HttplibTestSuite) TearDownSuite() { + _ = h.l.Close() +} + +func TestHttplib(t *testing.T) { + suite.Run(t, &HttplibTestSuite{}) +} + +func (h *HttplibTestSuite) TestResponse() { + req := Get("http://localhost:8080/get") + _, err := req.Response() + require.NoError(h.T(), err) +} + +func (h *HttplibTestSuite) TestDoRequest() { + t := h.T() + req := Get("http://localhost:8080/redirect") + retryAmount := 1 + req.Retries(1) + req.RetryDelay(1400 * time.Millisecond) + retryDelay := 1400 * time.Millisecond + + req.setting.CheckRedirect = func(redirectReq *http.Request, redirectVia []*http.Request) error { + return errors.New("Redirect triggered") + } + + startTime := time.Now().UnixNano() / int64(time.Millisecond) + + _, err := req.Response() + require.Error(t, err) + + endTime := time.Now().UnixNano() / int64(time.Millisecond) + elapsedTime := endTime - startTime + delayedTime := int64(retryAmount) * retryDelay.Milliseconds() + + if elapsedTime < delayedTime { + t.Errorf("Not enough retries. Took %dms. Delay was meant to take %dms", elapsedTime, delayedTime) + } +} + +func (h *HttplibTestSuite) TestGet() { + t := h.T() + req := Get("http://localhost:8080/get") + b, err := req.Bytes() + require.NoError(t, err) + + s, err := req.String() + require.NoError(t, err) + require.Equal(t, string(b), s) +} + +func (h *HttplibTestSuite) TestSimplePost() { + t := h.T() + v := "smallfish" + req := Post("http://localhost:8080/post") + req.Param("username", v) + + str, err := req.String() + require.NoError(t, err) + n := strings.Index(str, v) + require.NotEqual(t, -1, n) +} + +func (h *HttplibTestSuite) TestSimplePut() { + t := h.T() + _, err := Put("http://localhost:8080/put").String() + require.NoError(t, err) +} + +func (h *HttplibTestSuite) TestSimpleDelete() { + t := h.T() + _, err := Delete("http://localhost:8080/delete").String() + require.NoError(t, err) +} + +func (h *HttplibTestSuite) TestSimpleDeleteParam() { + t := h.T() + _, err := Delete("http://localhost:8080/delete").Param("key", "val").String() + require.NoError(t, err) +} + +func (h *HttplibTestSuite) TestWithCookie() { + t := h.T() + v := "smallfish" + _, err := Get("http://localhost:8080/cookies/set?k1=" + v).SetEnableCookie(true).String() + require.NoError(t, err) + + str, err := Get("http://localhost:8080/cookies").SetEnableCookie(true).String() + require.NoError(t, err) + + n := strings.Index(str, v) + require.NotEqual(t, -1, n) +} + +func (h *HttplibTestSuite) TestWithBasicAuth() { + t := h.T() + str, err := Get("http://localhost:8080/basic-auth/user/passwd").SetBasicAuth("user", "passwd").String() + require.NoError(t, err) + n := strings.Index(str, "authenticated") + require.NotEqual(t, -1, n) +} + +func (h *HttplibTestSuite) TestWithUserAgent() { + t := h.T() + v := "beego" + str, err := Get("http://localhost:8080/headers").SetUserAgent(v).String() + require.NoError(t, err) + n := strings.Index(str, v) + require.NotEqual(t, -1, n) +} + +func (h *HttplibTestSuite) TestWithSetting() { + t := h.T() + v := "beego" + var setting BeegoHTTPSettings + setting.EnableCookie = true + setting.UserAgent = v + setting.Transport = &http.Transport{ + DialContext: (&net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + DualStack: true, + }).DialContext, + MaxIdleConns: 50, + IdleConnTimeout: 90 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + } + setting.ReadWriteTimeout = 5 * time.Second + SetDefaultSetting(setting) + + str, err := Get("http://localhost:8080/get").String() + require.NoError(t, err) + n := strings.Index(str, v) + require.NotEqual(t, -1, n) +} + +func (h *HttplibTestSuite) TestToJson() { + t := h.T() + req := Get("http://localhost:8080/ip") + resp, err := req.Response() + require.NoError(t, err) + t.Log(resp) + + type IP struct { + Origin string `json:"origin"` + } + var ip IP + err = req.ToJSON(&ip) + require.NoError(t, err) + require.Equal(t, "127.0.0.1", ip.Origin) + + ips := strings.Split(ip.Origin, ",") + require.NotEmpty(t, ips) +} + +func (h *HttplibTestSuite) TestToFile() { + t := h.T() + f := "beego_testfile" + req := Get("http://localhost:8080/ip") + err := req.ToFile(f) + require.NoError(t, err) + defer os.Remove(f) + + b, err := os.ReadFile(f) + n := bytes.Index(b, []byte("origin")) + require.NotEqual(t, -1, n) +} + +func (h *HttplibTestSuite) TestToFileDir() { + t := h.T() + f := "./files/beego_testfile" + req := Get("http://localhost:8080/ip") + err := req.ToFile(f) + require.NoError(t, err) + defer os.RemoveAll("./files") + b, err := os.ReadFile(f) + require.NoError(t, err) + n := bytes.Index(b, []byte("origin")) + require.NotEqual(t, -1, n) +} + +func (h *HttplibTestSuite) TestHeader() { + t := h.T() + req := Get("http://localhost:8080/headers") + req.Header("User-Agent", "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_9_0) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/31.0.1650.57 Safari/537.36") + _, err := req.String() + require.NoError(t, err) +} + +// TestAddFilter make sure that AddFilters only work for the specific request +func (h *HttplibTestSuite) TestAddFilter() { + t := h.T() + req := Get("http://beego.vip") + req.AddFilters(func(next Filter) Filter { + return func(ctx context.Context, req *BeegoHTTPRequest) (*http.Response, error) { + return next(ctx, req) + } + }) + + r := Get("http://beego.vip") + assert.Equal(t, 1, len(req.setting.FilterChains)-len(r.setting.FilterChains)) +} + +func (h *HttplibTestSuite) TestFilterChainOrder() { + t := h.T() + req := Get("http://beego.vip") + req.AddFilters(func(next Filter) Filter { + return func(ctx context.Context, req *BeegoHTTPRequest) (*http.Response, error) { + return NewHttpResponseWithJsonBody("first"), nil + } + }) + + req.AddFilters(func(next Filter) Filter { + return func(ctx context.Context, req *BeegoHTTPRequest) (*http.Response, error) { + return NewHttpResponseWithJsonBody("second"), nil + } + }) + + resp, err := req.DoRequestWithCtx(context.Background()) + assert.Nil(t, err) + data := make([]byte, 5) + _, _ = resp.Body.Read(data) + assert.Equal(t, "first", string(data)) +} + +func (h *HttplibTestSuite) TestHead() { + t := h.T() + req := Head("http://beego.vip") + assert.NotNil(t, req) + assert.Equal(t, "HEAD", req.req.Method) +} + +func (h *HttplibTestSuite) TestDelete() { + t := h.T() + req := Delete("http://beego.vip") + assert.NotNil(t, req) + assert.Equal(t, "DELETE", req.req.Method) +} + +func (h *HttplibTestSuite) TestPost() { + t := h.T() + req := Post("http://beego.vip") + assert.NotNil(t, req) + assert.Equal(t, "POST", req.req.Method) +} + +func (h *HttplibTestSuite) TestPut() { + t := h.T() + req := Put("http://beego.vip") + assert.NotNil(t, req) + assert.Equal(t, "PUT", req.req.Method) +} + +func (h *HttplibTestSuite) TestRetry() { + defaultSetting.Retries = 2 + testCases := []struct { + name string + req func(t *testing.T) *BeegoHTTPRequest + wantErr error + }{ + { + name: "retry_failed", + req: func(t *testing.T) *BeegoHTTPRequest { + req := NewBeegoRequest("http://localhost:8080/retry", http.MethodPost) + req.Body("retry body") + return req + }, + wantErr: io.EOF, + }, + } + + for _, tc := range testCases { + h.T().Run(tc.name, func(t *testing.T) { + req := tc.req(t) + resp, err := req.DoRequest() + assert.ErrorIs(t, err, tc.wantErr) + assert.Nil(t, resp) + }) + } +} + +func TestNewBeegoRequest(t *testing.T) { + req := NewBeegoRequest("http://beego.vip", "GET") + assert.NotNil(t, req) + assert.Equal(t, "GET", req.req.Method) + + // invalid case but still go request + req = NewBeegoRequest("httpa\ta://beego.vip", "GET") + assert.NotNil(t, req) +} + +func TestNewBeegoRequestWithCtx(t *testing.T) { + req := NewBeegoRequestWithCtx(context.Background(), "http://beego.vip", "GET") + assert.NotNil(t, req) + assert.Equal(t, "GET", req.req.Method) + + // bad url but still get request + req = NewBeegoRequestWithCtx(context.Background(), "httpa\ta://beego.vip", "GET") + assert.NotNil(t, req) + + // bad method but still get request + req = NewBeegoRequestWithCtx(context.Background(), "http://beego.vip", "G\tET") + assert.NotNil(t, req) + assert.NotNil(t, req.copyBody) +} + +func TestBeegoHTTPRequestSetProtocolVersion(t *testing.T) { + req := NewBeegoRequest("http://beego.vip", "GET") + assert.Equal(t, 1, req.req.ProtoMajor) + assert.Equal(t, 1, req.req.ProtoMinor) + + req.SetProtocolVersion("") + assert.Equal(t, "HTTP/1.1", req.req.Proto) + assert.Equal(t, 1, req.req.ProtoMajor) + assert.Equal(t, 1, req.req.ProtoMinor) + + // invalid case + req.SetProtocolVersion("HTTP/aaa1.1") + assert.Equal(t, "HTTP/1.1", req.req.Proto) + assert.Equal(t, 1, req.req.ProtoMajor) + assert.Equal(t, 1, req.req.ProtoMinor) +} + +func TestBeegoHTTPRequestHeader(t *testing.T) { + req := Post("http://beego.vip") + key, value := "test-header", "test-header-value" + req.Header(key, value) + assert.Equal(t, value, req.req.Header.Get(key)) +} + +func TestBeegoHTTPRequestSetHost(t *testing.T) { + req := Post("http://beego.vip") + host := "test-hose" + req.SetHost(host) + assert.Equal(t, host, req.req.Host) +} + +func TestBeegoHTTPRequestParam(t *testing.T) { + req := Post("http://beego.vip") + key, value := "test-param", "test-param-value" + req.Param(key, value) + assert.Equal(t, value, req.params[key][0]) + + value1 := "test-param-value-1" + req.Param(key, value1) + assert.Equal(t, value1, req.params[key][1]) +} + +func TestBeegoHTTPRequestBody(t *testing.T) { + req := Post("http://beego.vip") + body := `hello, world` + req.Body([]byte(body)) + assert.Equal(t, int64(len(body)), req.req.ContentLength) + assert.NotNil(t, req.req.GetBody) + assert.NotNil(t, req.req.Body) + + body = "hhhh, I am test" + req.Body(body) + assert.Equal(t, int64(len(body)), req.req.ContentLength) + assert.NotNil(t, req.req.GetBody) + assert.NotNil(t, req.req.Body) + + // invalid case + req.Body(13) +} + +type user struct { + Name string `xml:"name"` +} + +func TestBeegoHTTPRequestXMLBody(t *testing.T) { + req := Post("http://beego.vip") + body := &user{ + Name: "Tom", + } + _, err := req.XMLBody(body) + assert.True(t, req.req.ContentLength > 0) + assert.Nil(t, err) + assert.NotNil(t, req.req.GetBody) +} + +func TestBeegoHTTPRequestJSONMarshal(t *testing.T) { + req := Post("http://beego.vip") + req.SetEscapeHTML(false) + body := map[string]interface{}{ + "escape": "left&right", + } + b, _ := req.JSONMarshal(body) + assert.Equal(t, fmt.Sprintf(`{"escape":"left&right"}%s`, "\n"), string(b)) +} diff --git a/client/httplib/mock/mock.go b/client/httplib/mock/mock.go new file mode 100644 index 0000000000..421a7a45f2 --- /dev/null +++ b/client/httplib/mock/mock.go @@ -0,0 +1,78 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mock + +import ( + "context" + "net/http" + + "github.com/beego/beego/v2/client/httplib" + "github.com/beego/beego/v2/core/logs" +) + +const mockCtxKey = "beego-httplib-mock" + +func init() { + InitMockSetting() +} + +type Stub interface { + Mock(cond RequestCondition, resp *http.Response, err error) + Clear() + MockByPath(path string, resp *http.Response, err error) +} + +var mockFilter = &MockResponseFilter{} + +func InitMockSetting() { + httplib.AddDefaultFilter(mockFilter.FilterChain) +} + +func StartMock() Stub { + return mockFilter +} + +func CtxWithMock(ctx context.Context, mock ...*Mock) context.Context { + return context.WithValue(ctx, mockCtxKey, mock) +} + +func mockFromCtx(ctx context.Context) []*Mock { + ms := ctx.Value(mockCtxKey) + if ms != nil { + if res, ok := ms.([]*Mock); ok { + return res + } + logs.Error("mockCtxKey found in context, but value is not type []*Mock") + } + return nil +} + +type Mock struct { + cond RequestCondition + resp *http.Response + err error +} + +func NewMockByPath(path string, resp *http.Response, err error) *Mock { + return NewMock(NewSimpleCondition(path), resp, err) +} + +func NewMock(con RequestCondition, resp *http.Response, err error) *Mock { + return &Mock{ + cond: con, + resp: resp, + err: err, + } +} diff --git a/client/httplib/mock/mock_condition.go b/client/httplib/mock/mock_condition.go new file mode 100644 index 0000000000..912699ff47 --- /dev/null +++ b/client/httplib/mock/mock_condition.go @@ -0,0 +1,176 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mock + +import ( + "context" + "encoding/json" + "net/textproto" + "regexp" + + "github.com/beego/beego/v2/client/httplib" +) + +type RequestCondition interface { + Match(ctx context.Context, req *httplib.BeegoHTTPRequest) bool +} + +// reqCondition create condition +// - path: same path +// - pathReg: request path match pathReg +// - method: same method +// - Query parameters (key, value) +// - header (key, value) +// - Body json format, contains specific (key, value). +type SimpleCondition struct { + pathReg string + path string + method string + query map[string]string + header map[string]string + body map[string]interface{} +} + +func NewSimpleCondition(path string, opts ...simpleConditionOption) *SimpleCondition { + sc := &SimpleCondition{ + path: path, + query: make(map[string]string), + header: make(map[string]string), + body: map[string]interface{}{}, + } + for _, o := range opts { + o(sc) + } + return sc +} + +func (sc *SimpleCondition) Match(ctx context.Context, req *httplib.BeegoHTTPRequest) bool { + var res bool + if len(sc.path) > 0 { + res = sc.matchPath(ctx, req) + } else if len(sc.pathReg) > 0 { + res = sc.matchPathReg(ctx, req) + } else { + return false + } + return res && + sc.matchMethod(ctx, req) && + sc.matchQuery(ctx, req) && + sc.matchHeader(ctx, req) && + sc.matchBodyFields(ctx, req) +} + +func (sc *SimpleCondition) matchPath(ctx context.Context, req *httplib.BeegoHTTPRequest) bool { + path := req.GetRequest().URL.Path + return path == sc.path +} + +func (sc *SimpleCondition) matchPathReg(ctx context.Context, req *httplib.BeegoHTTPRequest) bool { + path := req.GetRequest().URL.Path + if b, err := regexp.Match(sc.pathReg, []byte(path)); err == nil { + return b + } + return false +} + +func (sc *SimpleCondition) matchQuery(ctx context.Context, req *httplib.BeegoHTTPRequest) bool { + qs := req.GetRequest().URL.Query() + for k, v := range sc.query { + if uv, ok := qs[k]; !ok || uv[0] != v { + return false + } + } + return true +} + +func (sc *SimpleCondition) matchHeader(ctx context.Context, req *httplib.BeegoHTTPRequest) bool { + headers := req.GetRequest().Header + for k, v := range sc.header { + if uv, ok := headers[k]; !ok || uv[0] != v { + return false + } + } + return true +} + +func (sc *SimpleCondition) matchBodyFields(ctx context.Context, req *httplib.BeegoHTTPRequest) bool { + if len(sc.body) == 0 { + return true + } + getBody := req.GetRequest().GetBody + body, err := getBody() + if err != nil { + return false + } + bytes := make([]byte, req.GetRequest().ContentLength) + _, err = body.Read(bytes) + if err != nil { + return false + } + + m := make(map[string]interface{}) + + err = json.Unmarshal(bytes, &m) + + if err != nil { + return false + } + + for k, v := range sc.body { + if uv, ok := m[k]; !ok || uv != v { + return false + } + } + return true +} + +func (sc *SimpleCondition) matchMethod(ctx context.Context, req *httplib.BeegoHTTPRequest) bool { + if len(sc.method) > 0 { + return sc.method == req.GetRequest().Method + } + return true +} + +type simpleConditionOption func(sc *SimpleCondition) + +func WithPathReg(pathReg string) simpleConditionOption { + return func(sc *SimpleCondition) { + sc.pathReg = pathReg + } +} + +func WithQuery(key, value string) simpleConditionOption { + return func(sc *SimpleCondition) { + sc.query[key] = value + } +} + +func WithHeader(key, value string) simpleConditionOption { + return func(sc *SimpleCondition) { + sc.header[textproto.CanonicalMIMEHeaderKey(key)] = value + } +} + +func WithJsonBodyFields(field string, value interface{}) simpleConditionOption { + return func(sc *SimpleCondition) { + sc.body[field] = value + } +} + +func WithMethod(method string) simpleConditionOption { + return func(sc *SimpleCondition) { + sc.method = method + } +} diff --git a/client/httplib/mock/mock_condition_test.go b/client/httplib/mock/mock_condition_test.go new file mode 100644 index 0000000000..79bc7ad86c --- /dev/null +++ b/client/httplib/mock/mock_condition_test.go @@ -0,0 +1,119 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mock + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/beego/beego/v2/client/httplib" +) + +func TestSimpleConditionMatchPath(t *testing.T) { + sc := NewSimpleCondition("/abc/s") + res := sc.Match(context.Background(), httplib.Get("http://localhost:8080/abc/s")) + assert.True(t, res) +} + +func TestSimpleConditionMatchQuery(t *testing.T) { + k, v := "my-key", "my-value" + sc := NewSimpleCondition("/abc/s") + res := sc.Match(context.Background(), httplib.Get("http://localhost:8080/abc/s?my-key=my-value")) + assert.True(t, res) + + sc = NewSimpleCondition("/abc/s", WithQuery(k, v)) + res = sc.Match(context.Background(), httplib.Get("http://localhost:8080/abc/s?my-key=my-value")) + assert.True(t, res) + + res = sc.Match(context.Background(), httplib.Get("http://localhost:8080/abc/s?my-key=my-valuesss")) + assert.False(t, res) + + res = sc.Match(context.Background(), httplib.Get("http://localhost:8080/abc/s?my-key-a=my-value")) + assert.False(t, res) + + res = sc.Match(context.Background(), httplib.Get("http://localhost:8080/abc/s?my-key=my-value&abc=hello")) + assert.True(t, res) +} + +func TestSimpleConditionMatchHeader(t *testing.T) { + k, v := "my-header", "my-header-value" + sc := NewSimpleCondition("/abc/s") + req := httplib.Get("http://localhost:8080/abc/s") + assert.True(t, sc.Match(context.Background(), req)) + + req = httplib.Get("http://localhost:8080/abc/s") + req.Header(k, v) + assert.True(t, sc.Match(context.Background(), req)) + + sc = NewSimpleCondition("/abc/s", WithHeader(k, v)) + req.Header(k, v) + assert.True(t, sc.Match(context.Background(), req)) + + req.Header(k, "invalid") + assert.False(t, sc.Match(context.Background(), req)) +} + +func TestSimpleConditionMatchBodyField(t *testing.T) { + sc := NewSimpleCondition("/abc/s") + req := httplib.Post("http://localhost:8080/abc/s") + + assert.True(t, sc.Match(context.Background(), req)) + + req.Body(`{ + "body-field": 123 +}`) + assert.True(t, sc.Match(context.Background(), req)) + + k := "body-field" + v := float64(123) + sc = NewSimpleCondition("/abc/s", WithJsonBodyFields(k, v)) + assert.True(t, sc.Match(context.Background(), req)) + + sc = NewSimpleCondition("/abc/s", WithJsonBodyFields(k, v)) + req.Body(`{ + "body-field": abc +}`) + assert.False(t, sc.Match(context.Background(), req)) + + sc = NewSimpleCondition("/abc/s", WithJsonBodyFields("body-field", "abc")) + req.Body(`{ + "body-field": "abc" +}`) + assert.True(t, sc.Match(context.Background(), req)) +} + +func TestSimpleConditionMatch(t *testing.T) { + sc := NewSimpleCondition("/abc/s") + req := httplib.Post("http://localhost:8080/abc/s") + + assert.True(t, sc.Match(context.Background(), req)) + + sc = NewSimpleCondition("/abc/s", WithMethod("POST")) + assert.True(t, sc.Match(context.Background(), req)) + + sc = NewSimpleCondition("/abc/s", WithMethod("GET")) + assert.False(t, sc.Match(context.Background(), req)) +} + +func TestSimpleConditionMatchPathReg(t *testing.T) { + sc := NewSimpleCondition("", WithPathReg(`\/abc\/.*`)) + req := httplib.Post("http://localhost:8080/abc/s") + assert.True(t, sc.Match(context.Background(), req)) + + req = httplib.Post("http://localhost:8080/abcd/s") + assert.False(t, sc.Match(context.Background(), req)) +} diff --git a/client/httplib/mock/mock_filter.go b/client/httplib/mock/mock_filter.go new file mode 100644 index 0000000000..225d65f38f --- /dev/null +++ b/client/httplib/mock/mock_filter.go @@ -0,0 +1,61 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mock + +import ( + "context" + "net/http" + + "github.com/beego/beego/v2/client/httplib" +) + +// MockResponse will return mock response if find any suitable mock data +// if you want to test your code using httplib, you need this. +type MockResponseFilter struct { + ms []*Mock +} + +func NewMockResponseFilter() *MockResponseFilter { + return &MockResponseFilter{ + ms: make([]*Mock, 0, 1), + } +} + +func (m *MockResponseFilter) FilterChain(next httplib.Filter) httplib.Filter { + return func(ctx context.Context, req *httplib.BeegoHTTPRequest) (*http.Response, error) { + ms := mockFromCtx(ctx) + ms = append(ms, m.ms...) + for _, mock := range ms { + if mock.cond.Match(ctx, req) { + return mock.resp, mock.err + } + } + return next(ctx, req) + } +} + +func (m *MockResponseFilter) MockByPath(path string, resp *http.Response, err error) { + m.Mock(NewSimpleCondition(path), resp, err) +} + +func (m *MockResponseFilter) Clear() { + m.ms = make([]*Mock, 0, 1) +} + +// Mock add mock data +// If the cond.Match(...) = true, the resp and err will be returned +func (m *MockResponseFilter) Mock(cond RequestCondition, resp *http.Response, err error) { + m.ms = append(m.ms, NewMock(cond, resp, err)) +} diff --git a/client/httplib/mock/mock_filter_test.go b/client/httplib/mock/mock_filter_test.go new file mode 100644 index 0000000000..99c2e67f76 --- /dev/null +++ b/client/httplib/mock/mock_filter_test.go @@ -0,0 +1,63 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mock + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/beego/beego/v2/client/httplib" +) + +func TestMockResponseFilterFilterChain(t *testing.T) { + req := httplib.Get("http://localhost:8080/abc/s") + ft := NewMockResponseFilter() + + expectedResp := httplib.NewHttpResponseWithJsonBody(`{}`) + expectedErr := errors.New("expected error") + ft.Mock(NewSimpleCondition("/abc/s"), expectedResp, expectedErr) + + req.AddFilters(ft.FilterChain) + + resp, err := req.DoRequest() + assert.Equal(t, expectedErr, err) + assert.Equal(t, expectedResp, resp) + + req = httplib.Get("http://localhost:8080/abcd/s") + req.AddFilters(ft.FilterChain) + + resp, err = req.DoRequest() + assert.NotEqual(t, expectedErr, err) + assert.NotEqual(t, expectedResp, resp) + + req = httplib.Get("http://localhost:8080/abc/s") + req.AddFilters(ft.FilterChain) + expectedResp1 := httplib.NewHttpResponseWithJsonBody(map[string]string{}) + expectedErr1 := errors.New("expected error") + ft.Mock(NewSimpleCondition("/abc/abs/bbc"), expectedResp1, expectedErr1) + + resp, err = req.DoRequest() + assert.Equal(t, expectedErr, err) + assert.Equal(t, expectedResp, resp) + + req = httplib.Get("http://localhost:8080/abc/abs/bbc") + req.AddFilters(ft.FilterChain) + ft.Mock(NewSimpleCondition("/abc/abs/bbc"), expectedResp1, expectedErr1) + resp, err = req.DoRequest() + assert.Equal(t, expectedErr1, err) + assert.Equal(t, expectedResp1, resp) +} diff --git a/client/httplib/mock/mock_test.go b/client/httplib/mock/mock_test.go new file mode 100644 index 0000000000..ce7322dfb7 --- /dev/null +++ b/client/httplib/mock/mock_test.go @@ -0,0 +1,75 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mock + +import ( + "context" + "errors" + "net/http" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/beego/beego/v2/client/httplib" +) + +func TestStartMock(t *testing.T) { + // httplib.defaultSetting.FilterChains = []httplib.FilterChain{mockFilter.FilterChain} + + stub := StartMock() + // defer stub.Clear() + + expectedResp := httplib.NewHttpResponseWithJsonBody([]byte(`{}`)) + expectedErr := errors.New("expected err") + + stub.Mock(NewSimpleCondition("/abc"), expectedResp, expectedErr) + + resp, err := OriginalCodeUsingHttplib() + + assert.Equal(t, expectedErr, err) + assert.Equal(t, expectedResp, resp) +} + +// TestStartMock_Isolation Test StartMock that +// mock only work for this request +func TestStartMockIsolation(t *testing.T) { + // httplib.defaultSetting.FilterChains = []httplib.FilterChain{mockFilter.FilterChain} + // setup global stub + stub := StartMock() + globalMockResp := httplib.NewHttpResponseWithJsonBody([]byte(`{}`)) + globalMockErr := errors.New("expected err") + stub.Mock(NewSimpleCondition("/abc"), globalMockResp, globalMockErr) + + expectedResp := httplib.NewHttpResponseWithJsonBody(struct { + A string `json:"a"` + }{ + A: "aaa", + }) + expectedErr := errors.New("expected err aa") + m := NewMockByPath("/abc", expectedResp, expectedErr) + ctx := CtxWithMock(context.Background(), m) + + resp, err := OriginnalCodeUsingHttplibPassCtx(ctx) + assert.Equal(t, expectedErr, err) + assert.Equal(t, expectedResp, resp) +} + +func OriginnalCodeUsingHttplibPassCtx(ctx context.Context) (*http.Response, error) { + return httplib.Get("http://localhost:7777/abc").DoRequestWithCtx(ctx) +} + +func OriginalCodeUsingHttplib() (*http.Response, error) { + return httplib.Get("http://localhost:7777/abc").DoRequest() +} diff --git a/client/httplib/module.go b/client/httplib/module.go new file mode 100644 index 0000000000..8503133c08 --- /dev/null +++ b/client/httplib/module.go @@ -0,0 +1,17 @@ +// Copyright 2021 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package httplib + +const moduleName = "httplib" diff --git a/client/httplib/setting.go b/client/httplib/setting.go new file mode 100644 index 0000000000..3d8d195c53 --- /dev/null +++ b/client/httplib/setting.go @@ -0,0 +1,87 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package httplib + +import ( + "crypto/tls" + "net/http" + "net/http/cookiejar" + "net/url" + "sync" + "time" +) + +// BeegoHTTPSettings is the http.Client setting +type BeegoHTTPSettings struct { + UserAgent string + ConnectTimeout time.Duration + ReadWriteTimeout time.Duration + TLSClientConfig *tls.Config + Proxy func(*http.Request) (*url.URL, error) + Transport http.RoundTripper + CheckRedirect func(req *http.Request, via []*http.Request) error + EnableCookie bool + Gzip bool + Retries int // if set to -1 means will retry forever + RetryDelay time.Duration + FilterChains []FilterChain + EscapeHTML bool // if set to false means will not escape escape HTML special characters during processing, default true +} + +// createDefaultCookie creates a global cookiejar to store cookies. +func createDefaultCookie() { + settingMutex.Lock() + defer settingMutex.Unlock() + defaultCookieJar, _ = cookiejar.New(nil) +} + +// SetDefaultSetting overwrites default settings +// Keep in mind that when you invoke the SetDefaultSetting +// some methods invoked before SetDefaultSetting +func SetDefaultSetting(setting BeegoHTTPSettings) { + settingMutex.Lock() + defer settingMutex.Unlock() + defaultSetting = setting +} + +// GetDefaultSetting return current default setting +func GetDefaultSetting() BeegoHTTPSettings { + return defaultSetting +} + +var defaultSetting = BeegoHTTPSettings{ + UserAgent: "beegoServer", + ConnectTimeout: 60 * time.Second, + ReadWriteTimeout: 60 * time.Second, + Gzip: true, + FilterChains: make([]FilterChain, 0, 4), + EscapeHTML: true, +} + +var ( + defaultCookieJar http.CookieJar + settingMutex sync.Mutex +) + +// AddDefaultFilter add a new filter into defaultSetting +// Be careful about using this method if you invoke SetDefaultSetting somewhere +func AddDefaultFilter(fc FilterChain) { + settingMutex.Lock() + defer settingMutex.Unlock() + if defaultSetting.FilterChains == nil { + defaultSetting.FilterChains = make([]FilterChain, 0, 4) + } + defaultSetting.FilterChains = append(defaultSetting.FilterChains, fc) +} diff --git a/testing/client.go b/client/httplib/testing/client.go similarity index 86% rename from testing/client.go rename to client/httplib/testing/client.go index c3737e9c64..43e2e9680b 100644 --- a/testing/client.go +++ b/client/httplib/testing/client.go @@ -15,25 +15,26 @@ package testing import ( - "github.com/astaxie/beego/config" - "github.com/astaxie/beego/httplib" + "github.com/beego/beego/v2/client/httplib" ) -var port = "" -var baseURL = "http://localhost:" +var ( + port = "" + baseURL = "http://localhost:" +) // TestHTTPRequest beego test request client type TestHTTPRequest struct { httplib.BeegoHTTPRequest } +func SetTestingPort(p string) { + port = p +} + func getPort() string { if port == "" { - config, err := config.NewConfig("ini", "../conf/app.conf") - if err != nil { - return "8080" - } - port = config.String("httpport") + port = "8080" return port } return port diff --git a/orm/README.md b/client/orm/README.md similarity index 89% rename from orm/README.md rename to client/orm/README.md index 6e808d2ad3..a56a0fc1ef 100644 --- a/orm/README.md +++ b/client/orm/README.md @@ -1,6 +1,6 @@ # beego orm -[![Build Status](https://drone.io/github.com/astaxie/beego/status.png)](https://drone.io/github.com/astaxie/beego/latest) +[![Build Status](https://drone.io/github.com/beego/beego/v2/status.png)](https://drone.io/github.com/beego/beego/v2/latest) A powerful orm framework for go. @@ -27,7 +27,7 @@ more features please read the docs **Install:** - go get github.com/astaxie/beego/orm + go get github.com/beego/beego/v2/client/orm ## Changelog @@ -45,7 +45,7 @@ package main import ( "fmt" - "github.com/astaxie/beego/orm" + "github.com/beego/beego/v2/client/orm" _ "github.com/go-sql-driver/mysql" // import your used driver ) @@ -151,9 +151,3 @@ like this: note: not recommend use this in product env. -## Docs - -more details and examples in docs and test - -[documents](http://beego.me/docs/mvc/model/overview.md) - diff --git a/client/orm/clauses/const.go b/client/orm/clauses/const.go new file mode 100644 index 0000000000..38a6d55696 --- /dev/null +++ b/client/orm/clauses/const.go @@ -0,0 +1,6 @@ +package clauses + +const ( + ExprSep = "__" + ExprDot = "." +) diff --git a/client/orm/clauses/order_clause/order.go b/client/orm/clauses/order_clause/order.go new file mode 100644 index 0000000000..bdb2d1ca08 --- /dev/null +++ b/client/orm/clauses/order_clause/order.go @@ -0,0 +1,104 @@ +package order_clause + +import ( + "strings" + + "github.com/beego/beego/v2/client/orm/clauses" +) + +type Sort int8 + +const ( + None Sort = 0 + Ascending Sort = 1 + Descending Sort = 2 +) + +type Option func(order *Order) + +type Order struct { + column string + sort Sort + isRaw bool +} + +func Clause(options ...Option) *Order { + o := &Order{} + for _, option := range options { + option(o) + } + + return o +} + +func (o *Order) GetColumn() string { + return o.column +} + +func (o *Order) GetSort() Sort { + return o.sort +} + +func (o *Order) SortString() string { + switch o.GetSort() { + case Ascending: + return "ASC" + case Descending: + return "DESC" + } + + return `` +} + +func (o *Order) IsRaw() bool { + return o.isRaw +} + +func ParseOrder(expressions ...string) []*Order { + var orders []*Order + for _, expression := range expressions { + sort := Ascending + column := strings.ReplaceAll(expression, clauses.ExprSep, clauses.ExprDot) + if column[0] == '-' { + sort = Descending + column = column[1:] + } + + orders = append(orders, &Order{ + column: column, + sort: sort, + }) + } + + return orders +} + +func Column(column string) Option { + return func(order *Order) { + order.column = strings.ReplaceAll(column, clauses.ExprSep, clauses.ExprDot) + } +} + +func sort(sort Sort) Option { + return func(order *Order) { + order.sort = sort + } +} + +func SortAscending() Option { + return sort(Ascending) +} + +func SortDescending() Option { + return sort(Descending) +} + +func SortNone() Option { + return sort(None) +} + +func Raw() Option { + return func(order *Order) { + order.isRaw = true + } +} diff --git a/client/orm/clauses/order_clause/order_test.go b/client/orm/clauses/order_clause/order_test.go new file mode 100644 index 0000000000..9f65870265 --- /dev/null +++ b/client/orm/clauses/order_clause/order_test.go @@ -0,0 +1,137 @@ +package order_clause + +import ( + "testing" +) + +func TestClause(t *testing.T) { + column := `a` + + o := Clause( + Column(column), + ) + + if o.GetColumn() != column { + t.Error() + } +} + +func TestSortAscending(t *testing.T) { + o := Clause( + SortAscending(), + ) + + if o.GetSort() != Ascending { + t.Error() + } +} + +func TestSortDescending(t *testing.T) { + o := Clause( + SortDescending(), + ) + + if o.GetSort() != Descending { + t.Error() + } +} + +func TestSortNone(t *testing.T) { + o1 := Clause( + SortNone(), + ) + + if o1.GetSort() != None { + t.Error() + } + + o2 := Clause() + + if o2.GetSort() != None { + t.Error() + } +} + +func TestRaw(t *testing.T) { + o1 := Clause() + + if o1.IsRaw() { + t.Error() + } + + o2 := Clause( + Raw(), + ) + + if !o2.IsRaw() { + t.Error() + } +} + +func TestColumn(t *testing.T) { + o1 := Clause( + Column(`aaa`), + ) + + if o1.GetColumn() != `aaa` { + t.Error() + } +} + +func TestParseOrder(t *testing.T) { + orders := ParseOrder( + `-user__status`, + `status`, + `user__status`, + ) + + t.Log(orders) + + if orders[0].GetSort() != Descending { + t.Error() + } + + if orders[0].GetColumn() != `user.status` { + t.Error() + } + + if orders[1].GetColumn() != `status` { + t.Error() + } + + if orders[1].GetSort() != Ascending { + t.Error() + } + + if orders[2].GetColumn() != `user.status` { + t.Error() + } +} + +func TestOrderGetColumn(t *testing.T) { + o := Clause( + Column(`user__id`), + ) + if o.GetColumn() != `user.id` { + t.Error() + } +} + +func TestSortString(t *testing.T) { + template := "got: %s, want: %s" + + o1 := Clause(sort(Sort(1))) + if o1.SortString() != "ASC" { + t.Errorf(template, o1.SortString(), "ASC") + } + + o2 := Clause(sort(Sort(2))) + if o2.SortString() != "DESC" { + t.Errorf(template, o2.SortString(), "DESC") + } + + o3 := Clause(sort(Sort(3))) + if o3.SortString() != `` { + t.Errorf(template, o3.SortString(), ``) + } +} diff --git a/orm/cmd.go b/client/orm/cmd.go similarity index 67% rename from orm/cmd.go rename to client/orm/cmd.go index 0ff4dc40d0..56aacaeeaa 100644 --- a/orm/cmd.go +++ b/client/orm/cmd.go @@ -15,10 +15,15 @@ package orm import ( + "context" "flag" "fmt" "os" "strings" + + "github.com/beego/beego/v2/client/orm/internal/utils" + + "github.com/beego/beego/v2/client/orm/internal/models" ) type commander interface { @@ -26,9 +31,7 @@ type commander interface { Run() error } -var ( - commands = make(map[string]commander) -) +var commands = make(map[string]commander) // print help. func printHelp(errs ...string) { @@ -46,7 +49,7 @@ func printHelp(errs ...string) { os.Exit(2) } -// RunCommand listen for orm command and then run it if command arguments passed. +// RunCommand listens for orm command and runs if command arguments have been passed. func RunCommand() { if len(os.Args) < 2 || os.Args[1] != "orm" { return @@ -54,7 +57,7 @@ func RunCommand() { BootStrap() - args := argString(os.Args[2:]) + args := utils.ArgString(os.Args[2:]) name := args.Get(0) if name == "help" { @@ -76,14 +79,14 @@ func RunCommand() { // sync database struct command interface. type commandSyncDb struct { - al *alias + al *DB force bool verbose bool noInfo bool rtOnError bool } -// parse orm command line arguments. +// Parse the orm command line arguments. func (d *commandSyncDb) Parse(args []string) { var name string @@ -93,23 +96,27 @@ func (d *commandSyncDb) Parse(args []string) { flagSet.BoolVar(&d.verbose, "v", false, "verbose info") flagSet.Parse(args) - d.al = getDbAlias(name) + d.al = getDB(name) } -// run orm line command. +// Run orm line command. func (d *commandSyncDb) Run() error { var drops []string + var err error if d.force { - drops = getDbDropSQL(d.al) + drops, err = getDbDropSQL(models.DefaultModelCache, d.al) + if err != nil { + return err + } } db := d.al.DB - if d.force { - for i, mi := range modelCache.allOrdered() { + if d.force && len(drops) > 0 { + for i, mi := range models.DefaultModelCache.AllOrdered() { query := drops[i] if !d.noInfo { - fmt.Printf("drop table `%s`\n", mi.table) + fmt.Printf("drop table `%s`\n", mi.Table) } _, err := db.Exec(query) if d.verbose { @@ -124,9 +131,12 @@ func (d *commandSyncDb) Run() error { } } - sqls, indexes := getDbCreateSQL(d.al) + createQueries, indexes, err := getDbCreateSQL(models.DefaultModelCache, d.al) + if err != nil { + return err + } - tables, err := d.al.DbBaser.GetTables(db) + tables, err := d.al.dbBaser.GetTables(db) if err != nil { if d.rtOnError { return err @@ -134,14 +144,21 @@ func (d *commandSyncDb) Run() error { fmt.Printf(" %s\n", err.Error()) } - for i, mi := range modelCache.allOrdered() { - if tables[mi.table] { + ctx := context.Background() + for i, mi := range models.DefaultModelCache.AllOrdered() { + + if !models.IsApplicableTableForDB(mi.AddrField, d.al.name) { + fmt.Printf("table `%s` is not applicable to database '%s'\n", mi.Table, d.al.name) + continue + } + + if tables[mi.Table] { if !d.noInfo { - fmt.Printf("table `%s` already exists, skip\n", mi.table) + fmt.Printf("table `%s` already exists, skip\n", mi.Table) } - var fields []*fieldInfo - columns, err := d.al.DbBaser.GetColumns(db, mi.table) + var fields []*models.FieldInfo + columns, err := d.al.dbBaser.GetColumns(ctx, db, mi.Table) if err != nil { if d.rtOnError { return err @@ -149,8 +166,8 @@ func (d *commandSyncDb) Run() error { fmt.Printf(" %s\n", err.Error()) } - for _, fi := range mi.fields.fieldsDB { - if _, ok := columns[fi.column]; !ok { + for _, fi := range mi.Fields.FieldsDB { + if _, ok := columns[fi.Column]; !ok { fields = append(fields, fi) } } @@ -159,7 +176,7 @@ func (d *commandSyncDb) Run() error { query := getColumnAddQuery(d.al, fi) if !d.noInfo { - fmt.Printf("add column `%s` for table `%s`\n", fi.fullName, mi.table) + fmt.Printf("add column `%s` for table `%s`\n", fi.FullName, mi.Table) } _, err := db.Exec(query) @@ -174,8 +191,8 @@ func (d *commandSyncDb) Run() error { } } - for _, idx := range indexes[mi.table] { - if !d.al.DbBaser.IndexExists(db, idx.Table, idx.Name) { + for _, idx := range indexes[mi.Table] { + if !d.al.dbBaser.IndexExists(ctx, db, idx.Table, idx.Name) { if !d.noInfo { fmt.Printf("create index `%s` for table `%s`\n", idx.Name, idx.Table) } @@ -198,11 +215,11 @@ func (d *commandSyncDb) Run() error { } if !d.noInfo { - fmt.Printf("create table `%s` \n", mi.table) + fmt.Printf("create table `%s` \n", mi.Table) } - queries := []string{sqls[i]} - for _, idx := range indexes[mi.table] { + queries := []string{createQueries[i]} + for _, idx := range indexes[mi.Table] { queries = append(queries, idx.SQL) } @@ -229,10 +246,10 @@ func (d *commandSyncDb) Run() error { // database creation commander interface implement. type commandSQLAll struct { - al *alias + al *DB } -// parse orm command line arguments. +// Parse orm command line arguments. func (d *commandSQLAll) Parse(args []string) { var name string @@ -240,16 +257,19 @@ func (d *commandSQLAll) Parse(args []string) { flagSet.StringVar(&name, "db", "default", "DataBase alias name") flagSet.Parse(args) - d.al = getDbAlias(name) + d.al = getDB(name) } -// run orm line command. +// Run orm line command. func (d *commandSQLAll) Run() error { - sqls, indexes := getDbCreateSQL(d.al) + createQueries, indexes, err := getDbCreateSQL(models.DefaultModelCache, d.al) + if err != nil { + return err + } var all []string - for i, mi := range modelCache.allOrdered() { - queries := []string{sqls[i]} - for _, idx := range indexes[mi.table] { + for i, mi := range models.DefaultModelCache.AllOrdered() { + queries := []string{createQueries[i]} + for _, idx := range indexes[mi.Table] { queries = append(queries, idx.SQL) } sql := strings.Join(queries, "\n") @@ -266,13 +286,13 @@ func init() { } // RunSyncdb run syncdb command line. -// name means table's alias name. default is "default". -// force means run next sql if the current is error. -// verbose means show all info when running command or not. +// name: Table's alias name (default is "default") +// force: Run the next sql command even if the current gave an error +// verbose: Print All information, useful for debugging func RunSyncdb(name string, force bool, verbose bool) error { BootStrap() - al := getDbAlias(name) + al := getDB(name) cmd := new(commandSyncDb) cmd.al = al cmd.force = force diff --git a/client/orm/cmd_utils.go b/client/orm/cmd_utils.go new file mode 100644 index 0000000000..feeac9df8a --- /dev/null +++ b/client/orm/cmd_utils.go @@ -0,0 +1,176 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "fmt" + "strings" + + "github.com/beego/beego/v2/client/orm/internal/models" +) + +type dbIndex struct { + Table string + Name string + SQL string +} + +// Get database column type string. +func getColumnTyp(db *DB, fi *models.FieldInfo) (col string) { + T := db.dbBaser.DbTypes() + fieldType := fi.FieldType + fieldSize := fi.Size + + defer func() { + // handling the placeholder, including %COL% + col = strings.ReplaceAll(col, "%COL%", fi.Column) + }() + +checkColumn: + switch fieldType { + case TypeBooleanField: + col = T["bool"] + case TypeVarCharField: + if db.driver == DRPostgres && fi.ToText { + col = T["string-text"] + } else { + col = fmt.Sprintf(T["string"], fieldSize) + } + case TypeCharField: + col = fmt.Sprintf(T["string-char"], fieldSize) + case TypeTextField: + col = T["string-text"] + case TypeTimeField: + col = T["time.Time-clock"] + case TypeDateField: + col = T["time.Time-date"] + case TypeDateTimeField: + // the precision of sqlite is not implemented + if db.driver == 2 || fi.TimePrecision == nil { + col = T["time.Time"] + } else { + s := T["time.Time-precision"] + col = fmt.Sprintf(s, *fi.TimePrecision) + } + + case TypeBitField: + col = T["int8"] + case TypeSmallIntegerField: + col = T["int16"] + case TypeIntegerField: + col = T["int32"] + case TypeBigIntegerField: + if db.driver == DRSqlite { + fieldType = TypeIntegerField + goto checkColumn + } + col = T["int64"] + case TypePositiveBitField: + col = T["uint8"] + case TypePositiveSmallIntegerField: + col = T["uint16"] + case TypePositiveIntegerField: + col = T["uint32"] + case TypePositiveBigIntegerField: + col = T["uint64"] + case TypeFloatField: + col = T["float64"] + case TypeDecimalField: + s := T["float64-decimal"] + if !strings.Contains(s, "%d") { + col = s + } else { + col = fmt.Sprintf(s, fi.Digits, fi.Decimals) + } + case TypeJSONField: + if db.driver != DRPostgres { + fieldType = TypeVarCharField + goto checkColumn + } + col = T["json"] + case TypeJsonbField: + if db.driver != DRPostgres { + fieldType = TypeVarCharField + goto checkColumn + } + col = T["jsonb"] + case RelForeignKey, RelOneToOne: + fieldType = fi.RelModelInfo.Fields.Pk.FieldType + fieldSize = fi.RelModelInfo.Fields.Pk.Size + goto checkColumn + } + + return +} + +// create alter sql string. +func getColumnAddQuery(al *DB, fi *models.FieldInfo) string { + Q := al.dbBaser.TableQuote() + typ := getColumnTyp(al, fi) + + if !fi.Null { + typ += " " + "NOT NULL" + } + + return fmt.Sprintf("ALTER TABLE %s%s%s ADD COLUMN %s%s%s %s %s", + Q, fi.Mi.Table, Q, + Q, fi.Column, Q, + typ, getColumnDefault(fi), + ) +} + +// Get string value for the attribute "DEFAULT" for the CREATE, ALTER commands +func getColumnDefault(fi *models.FieldInfo) string { + var v, t, d string + + // Skip default attribute if field is in relations + if fi.Rel || fi.Reverse { + return v + } + + t = " DEFAULT '%s' " + + // These defaults will be useful if there no config value orm:"default" and NOT NULL is on + switch fi.FieldType { + case TypeTimeField, TypeDateField, TypeDateTimeField, TypeTextField: + return v + + case TypeBitField, TypeSmallIntegerField, TypeIntegerField, + TypeBigIntegerField, TypePositiveBitField, TypePositiveSmallIntegerField, + TypePositiveIntegerField, TypePositiveBigIntegerField, TypeFloatField, + TypeDecimalField: + t = " DEFAULT %s " + d = "0" + case TypeBooleanField: + t = " DEFAULT %s " + d = "FALSE" + case TypeJSONField, TypeJsonbField: + d = "{}" + } + + if fi.ColDefault { + if !fi.Initial.Exist() { + v = fmt.Sprintf(t, "") + } else { + v = fmt.Sprintf(t, fi.Initial.String()) + } + } else { + if !fi.Null { + v = fmt.Sprintf(t, d) + } + } + + return v +} diff --git a/client/orm/cmd_utils_test.go b/client/orm/cmd_utils_test.go new file mode 100644 index 0000000000..cb159e5e1f --- /dev/null +++ b/client/orm/cmd_utils_test.go @@ -0,0 +1,53 @@ +// Copyright 2020 beego-dev +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/beego/beego/v2/client/orm/internal/models" +) + +func Test_getColumnTyp(t *testing.T) { + testCases := []struct { + name string + fi *models.FieldInfo + al *DB + + wantCol string + }{ + { + // https://github.com/beego/beego/issues/5254 + name: "issue 5254", + fi: &models.FieldInfo{ + FieldType: TypePositiveIntegerField, + Column: "my_col", + }, + al: &DB{ + dbBaser: newdbBasePostgres(), + }, + wantCol: `bigint CHECK("my_col" >= 0)`, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + col := getColumnTyp(tc.al, tc.fi) + assert.Equal(t, tc.wantCol, col) + }) + } +} diff --git a/orm/db.go b/client/orm/db.go similarity index 52% rename from orm/db.go rename to client/orm/db.go index 20d9c4cda1..76d2012742 100644 --- a/orm/db.go +++ b/client/orm/db.go @@ -15,55 +15,55 @@ package orm import ( + "context" "database/sql" "errors" "fmt" "reflect" "strings" "time" -) -const ( - formatTime = "15:04:05" - formatDate = "2006-01-02" - formatDateTime = "2006-01-02 15:04:05" -) + "github.com/beego/beego/v2/client/orm/internal/buffers" -var ( - // ErrMissPK missing pk error - ErrMissPK = errors.New("missed pk value") -) + "github.com/beego/beego/v2/client/orm/internal/utils" -var ( - operators = map[string]bool{ - "exact": true, - "iexact": true, - "contains": true, - "icontains": true, - // "regex": true, - // "iregex": true, - "gt": true, - "gte": true, - "lt": true, - "lte": true, - "eq": true, - "nq": true, - "ne": true, - "startswith": true, - "endswith": true, - "istartswith": true, - "iendswith": true, - "in": true, - "between": true, - // "year": true, - // "month": true, - // "day": true, - // "week_day": true, - "isnull": true, - // "search": true, - } + "github.com/beego/beego/v2/client/orm/internal/models" + + "github.com/beego/beego/v2/client/orm/hints" ) +// ErrMissPK missing pk error +var ErrMissPK = errors.New("missed pk value") + +var operators = map[string]bool{ + "exact": true, + "iexact": true, + "strictexact": true, + "contains": true, + "icontains": true, + // "regex": true, + // "iregex": true, + "gt": true, + "gte": true, + "lt": true, + "lte": true, + "eq": true, + "nq": true, + "ne": true, + "startswith": true, + "endswith": true, + "istartswith": true, + "iendswith": true, + "in": true, + "between": true, + // "year": true, + // "month": true, + // "day": true, + // "week_day": true, + "isnull": true, + // "search": true, +} + // an instance of dbBaser interface/ type dbBase struct { ins dbBaser @@ -72,8 +72,8 @@ type dbBase struct { // check dbBase implements dbBaser interface. var _ dbBaser = new(dbBase) -// get struct columns values as interface slice. -func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, cols []string, skipAuto bool, insert bool, names *[]string, tz *time.Location) (values []interface{}, autoFields []string, err error) { +// Get struct Columns values as interface slice. +func (d *dbBase) collectValues(mi *models.ModelInfo, ind reflect.Value, cols []string, skipAuto bool, insert bool, names *[]string, tz *time.Location) (values []interface{}, autoFields []string, err error) { if names == nil { ns := make([]string, 0, len(cols)) names = &ns @@ -81,13 +81,13 @@ func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, cols []string, values = make([]interface{}, 0, len(cols)) for _, column := range cols { - var fi *fieldInfo - if fi, _ = mi.fields.GetByAny(column); fi != nil { - column = fi.column + var fi *models.FieldInfo + if fi, _ = mi.Fields.GetByAny(column); fi != nil { + column = fi.Column } else { - panic(fmt.Errorf("wrong db field/column name `%s` for model `%s`", column, mi.fullName)) + panic(fmt.Errorf("wrong db field/column name `%s` for model `%s`", column, mi.FullName)) } - if !fi.dbcol || fi.auto && skipAuto { + if !fi.DBcol || fi.Auto && skipAuto { continue } value, err := d.collectFieldValue(mi, fi, ind, insert, tz) @@ -96,8 +96,8 @@ func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, cols []string, } // ignore empty value auto field - if insert && fi.auto { - if fi.fieldType&IsPositiveIntegerField > 0 { + if insert && fi.Auto { + if fi.FieldType&IsPositiveIntegerField > 0 { if vu, ok := value.(uint64); !ok || vu == 0 { continue } @@ -106,7 +106,7 @@ func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, cols []string, continue } } - autoFields = append(autoFields, fi.column) + autoFields = append(autoFields, fi.Column) } *names, values = append(*names, column), append(values, value) @@ -115,18 +115,18 @@ func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, cols []string, return } -// get one field value in struct column as interface. -func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Value, insert bool, tz *time.Location) (interface{}, error) { +// Get one field value in struct column as interface. +func (d *dbBase) collectFieldValue(mi *models.ModelInfo, fi *models.FieldInfo, ind reflect.Value, insert bool, tz *time.Location) (interface{}, error) { var value interface{} - if fi.pk { + if fi.Pk { _, value, _ = getExistPk(mi, ind) } else { - field := ind.FieldByIndex(fi.fieldIndex) - if fi.isFielder { - f := field.Addr().Interface().(Fielder) + field := ind.FieldByIndex(fi.FieldIndex) + if fi.IsFielder { + f := field.Addr().Interface().(models.Fielder) value = f.RawValue() } else { - switch fi.fieldType { + switch fi.FieldType { case TypeBooleanField: if nb, ok := field.Interface().(sql.NullBool); ok { value = nil @@ -172,7 +172,7 @@ func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Val } else { vu := field.Interface() if _, ok := vu.(float32); ok { - value, _ = StrTo(ToStr(vu)).Float64() + value, _ = utils.StrTo(utils.ToStr(vu)).Float64() } else { value = field.Float() } @@ -189,7 +189,7 @@ func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Val } default: switch { - case fi.fieldType&IsPositiveIntegerField > 0: + case fi.FieldType&IsPositiveIntegerField > 0: if field.Kind() == reflect.Ptr { if field.IsNil() { value = nil @@ -199,7 +199,7 @@ func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Val } else { value = field.Uint() } - case fi.fieldType&IsIntegerField > 0: + case fi.FieldType&IsIntegerField > 0: if ni, ok := field.Interface().(sql.NullInt64); ok { value = nil if ni.Valid { @@ -214,25 +214,25 @@ func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Val } else { value = field.Int() } - case fi.fieldType&IsRelField > 0: + case fi.FieldType&IsRelField > 0: if field.IsNil() { value = nil } else { - if _, vu, ok := getExistPk(fi.relModelInfo, reflect.Indirect(field)); ok { + if _, vu, ok := getExistPk(fi.RelModelInfo, reflect.Indirect(field)); ok { value = vu } else { value = nil } } - if !fi.null && value == nil { - return nil, fmt.Errorf("field `%s` cannot be NULL", fi.fullName) + if !fi.Null && value == nil { + return nil, fmt.Errorf("field `%s` cannot be NULL", fi.FullName) } } } } - switch fi.fieldType { + switch fi.FieldType { case TypeTimeField, TypeDateField, TypeDateTimeField: - if fi.autoNow || fi.autoNowAdd && insert { + if fi.AutoNow || fi.AutoNowAdd && insert { if insert { if t, ok := value.(time.Time); ok && !t.IsZero() { break @@ -241,8 +241,8 @@ func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Val tnow := time.Now() d.ins.TimeToDB(&tnow, tz) value = tnow - if fi.isFielder { - f := field.Addr().Interface().(Fielder) + if fi.IsFielder { + f := field.Addr().Interface().(models.Fielder) f.SetRaw(tnow.In(DefaultTimeLoc)) } else if field.Kind() == reflect.Ptr { v := tnow.In(DefaultTimeLoc) @@ -253,8 +253,8 @@ func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Val } case TypeJSONField, TypeJsonbField: if s, ok := value.(string); (ok && len(s) == 0) || value == nil { - if fi.colDefault && fi.initial.Exist() { - value = fi.initial.String() + if fi.ColDefault && fi.Initial.Exist() { + value = fi.Initial.String() } else { value = nil } @@ -264,15 +264,15 @@ func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Val return value, nil } -// create insert sql preparation statement object. -func (d *dbBase) PrepareInsert(q dbQuerier, mi *modelInfo) (stmtQuerier, string, error) { +// PrepareInsert create insert sql preparation statement object. +func (d *dbBase) PrepareInsert(ctx context.Context, q dbQuerier, mi *models.ModelInfo) (stmtQuerier, string, error) { Q := d.ins.TableQuote() - dbcols := make([]string, 0, len(mi.fields.dbcols)) - marks := make([]string, 0, len(mi.fields.dbcols)) - for _, fi := range mi.fields.fieldsDB { - if !fi.auto { - dbcols = append(dbcols, fi.column) + dbcols := make([]string, 0, len(mi.Fields.DBcols)) + marks := make([]string, 0, len(mi.Fields.DBcols)) + for _, fi := range mi.Fields.FieldsDB { + if !fi.Auto { + dbcols = append(dbcols, fi.Column) marks = append(marks, "?") } } @@ -280,19 +280,19 @@ func (d *dbBase) PrepareInsert(q dbQuerier, mi *modelInfo) (stmtQuerier, string, sep := fmt.Sprintf("%s, %s", Q, Q) columns := strings.Join(dbcols, sep) - query := fmt.Sprintf("INSERT INTO %s%s%s (%s%s%s) VALUES (%s)", Q, mi.table, Q, Q, columns, Q, qmarks) + query := fmt.Sprintf("INSERT INTO %s%s%s (%s%s%s) VALUES (%s)", Q, mi.Table, Q, Q, columns, Q, qmarks) d.ins.ReplaceMarks(&query) d.ins.HasReturningID(mi, &query) - stmt, err := q.Prepare(query) + stmt, err := q.PrepareContext(ctx, query) return stmt, query, err } -// insert struct with prepared statement and given struct reflect value. -func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) { - values, _, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, nil, tz) +// InsertStmt insert struct with prepared statement and given struct reflect value. +func (d *dbBase) InsertStmt(ctx context.Context, stmt stmtQuerier, mi *models.ModelInfo, ind reflect.Value, tz *time.Location) (int64, error) { + values, _, err := d.collectValues(mi, ind, mi.Fields.DBcols, true, true, nil, tz) if err != nil { return 0, err } @@ -303,7 +303,7 @@ func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value, err := row.Scan(&id) return id, err } - res, err := stmt.Exec(values...) + res, err := stmt.ExecContext(ctx, values...) if err == nil { return res.LastInsertId() } @@ -311,7 +311,7 @@ func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value, } // query sql ,read records and persist in dbBaser. -func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string, isForUpdate bool) error { +func (d *dbBase) Read(ctx context.Context, q dbQuerier, mi *models.ModelInfo, ind reflect.Value, tz *time.Location, cols []string, isForUpdate bool) error { var whereCols []string var args []interface{} @@ -336,8 +336,8 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo Q := d.ins.TableQuote() sep := fmt.Sprintf("%s, %s", Q, Q) - sels := strings.Join(mi.fields.dbcols, sep) - colsNum := len(mi.fields.dbcols) + sels := strings.Join(mi.Fields.DBcols, sep) + colsNum := len(mi.Fields.DBcols) sep = fmt.Sprintf("%s = ? AND %s", Q, Q) wheres := strings.Join(whereCols, sep) @@ -347,7 +347,7 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo forUpdate = "FOR UPDATE" } - query := fmt.Sprintf("SELECT %s%s%s FROM %s%s%s WHERE %s%s%s = ? %s", Q, sels, Q, Q, mi.table, Q, Q, wheres, Q, forUpdate) + query := fmt.Sprintf("SELECT %s%s%s FROM %s%s%s WHERE %s%s%s = ? %s", Q, sels, Q, Q, mi.Table, Q, Q, wheres, Q, forUpdate) refs := make([]interface{}, colsNum) for i := range refs { @@ -357,41 +357,73 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo d.ins.ReplaceMarks(&query) - row := q.QueryRow(query, args...) + row := q.QueryRowContext(ctx, query, args...) if err := row.Scan(refs...); err != nil { if err == sql.ErrNoRows { return ErrNoRows } return err } - elm := reflect.New(mi.addrField.Elem().Type()) + elm := reflect.New(mi.AddrField.Elem().Type()) + mind := reflect.Indirect(elm) + d.setColsValues(mi, &mind, mi.Fields.DBcols, refs, tz) + ind.Set(mind) + return nil +} + +func (d *dbBase) ReadRaw(ctx context.Context, q dbQuerier, mi *models.ModelInfo, ind reflect.Value, tz *time.Location, query string, args ...any) error { + //TODO: Improve code + rows, err := q.QueryContext(ctx, query, args...) + if err != nil { + return err + } + cols, err := rows.Columns() + if err != nil { + return err + } + refs := make([]any, len(cols)) + for i := range refs { + var ref any + refs[i] = &ref + } + if !rows.Next() { + return ErrNoRows + } + if err = rows.Scan(refs...); err != nil { + return err + } + elm := reflect.New(mi.AddrField.Elem().Type()) mind := reflect.Indirect(elm) - d.setColsValues(mi, &mind, mi.fields.dbcols, refs, tz) + d.setColsValues(mi, &mind, mi.Fields.DBcols, refs, tz) ind.Set(mind) return nil } -// execute insert sql dbQuerier with given struct reflect.Value. -func (d *dbBase) Insert(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) { - names := make([]string, 0, len(mi.fields.dbcols)) - values, autoFields, err := d.collectValues(mi, ind, mi.fields.dbcols, false, true, &names, tz) +func (*dbBase) ExecRaw(ctx context.Context, q dbQuerier, query string, args ...any) (sql.Result, error) { + return q.ExecContext(ctx, query, args...) +} + +// Insert execute insert sql dbQuerier with given struct reflect.Value. +func (d *dbBase) Insert(ctx context.Context, q dbQuerier, mi *models.ModelInfo, ind reflect.Value, tz *time.Location) (int64, error) { + names := make([]string, 0, len(mi.Fields.DBcols)) + values, autoFields, err := d.collectValues(mi, ind, mi.Fields.DBcols, false, true, &names, tz) if err != nil { return 0, err } - id, err := d.InsertValue(q, mi, false, names, values) + id, err := d.InsertValue(ctx, q, mi, false, names, values) if err != nil { return 0, err } if len(autoFields) > 0 { - err = d.ins.setval(q, mi, autoFields) + err = d.ins.setval(ctx, q, mi, autoFields) } return id, err } -// multi-insert sql with given slice struct reflect.Value. -func (d *dbBase) InsertMulti(q dbQuerier, mi *modelInfo, sind reflect.Value, bulk int, tz *time.Location) (int64, error) { +// InsertMulti multi-insert sql with given slice struct reflect.Value. +func (d *dbBase) InsertMulti(ctx context.Context, q dbQuerier, mi *models.ModelInfo, sind reflect.Value, bulk int, tz *time.Location) (int64, error) { var ( cnt int64 nums int @@ -399,32 +431,25 @@ func (d *dbBase) InsertMulti(q dbQuerier, mi *modelInfo, sind reflect.Value, bul names []string ) - // typ := reflect.Indirect(mi.addrField).Type() - length, autoFields := sind.Len(), make([]string, 0, 1) for i := 1; i <= length; i++ { ind := reflect.Indirect(sind.Index(i - 1)) - // Is this needed ? - // if !ind.Type().AssignableTo(typ) { - // return cnt, ErrArgs - // } - if i == 1 { var ( vus []interface{} err error ) - vus, autoFields, err = d.collectValues(mi, ind, mi.fields.dbcols, false, true, &names, tz) + vus, autoFields, err = d.collectValues(mi, ind, mi.Fields.DBcols, false, true, &names, tz) if err != nil { return cnt, err } values = make([]interface{}, bulk*len(vus)) nums += copy(values, vus) } else { - vus, _, err := d.collectValues(mi, ind, mi.fields.dbcols, false, true, nil, tz) + vus, _, err := d.collectValues(mi, ind, mi.Fields.DBcols, false, true, nil, tz) if err != nil { return cnt, err } @@ -437,7 +462,7 @@ func (d *dbBase) InsertMulti(q dbQuerier, mi *modelInfo, sind reflect.Value, bul } if i > 1 && i%bulk == 0 || length == i { - num, err := d.InsertValue(q, mi, true, names, values[:nums]) + num, err := d.InsertValue(ctx, q, mi, true, names, values[:nums]) if err != nil { return cnt, err } @@ -448,73 +473,145 @@ func (d *dbBase) InsertMulti(q dbQuerier, mi *modelInfo, sind reflect.Value, bul var err error if len(autoFields) > 0 { - err = d.ins.setval(q, mi, autoFields) + err = d.ins.setval(ctx, q, mi, autoFields) } return cnt, err } -// execute insert sql with given struct and given values. +// InsertValue execute insert sql with given struct and given values. // insert the given values, not the field values in struct. -func (d *dbBase) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []string, values []interface{}) (int64, error) { +func (d *dbBase) InsertValue(ctx context.Context, q dbQuerier, mi *models.ModelInfo, isMulti bool, names []string, values []interface{}) (int64, error) { + query := d.InsertValueSQL(names, values, isMulti, mi) + + if isMulti || !d.ins.HasReturningID(mi, &query) { + res, err := q.ExecContext(ctx, query, values...) + if err == nil { + if isMulti { + return res.RowsAffected() + } + + lastInsertId, err := res.LastInsertId() + if err != nil { + DebugLog.Println(ErrLastInsertIdUnavailable, ':', err) + return lastInsertId, ErrLastInsertIdUnavailable + } else { + return lastInsertId, nil + } + } + return 0, err + } + row := q.QueryRowContext(ctx, query, values...) + var id int64 + err := row.Scan(&id) + return id, err +} + +func (d *dbBase) InsertValueSQL(names []string, values []interface{}, isMulti bool, mi *models.ModelInfo) string { + buf := buffers.Get() + defer buffers.Put(buf) + Q := d.ins.TableQuote() + _, _ = buf.WriteString("INSERT INTO ") + _, _ = buf.WriteString(Q) + _, _ = buf.WriteString(mi.Table) + _, _ = buf.WriteString(Q) + + _, _ = buf.WriteString(" (") + for i, name := range names { + if i > 0 { + _, _ = buf.WriteString(", ") + } + _, _ = buf.WriteString(Q) + _, _ = buf.WriteString(name) + _, _ = buf.WriteString(Q) + } + _, _ = buf.WriteString(") VALUES (") + marks := make([]string, len(names)) for i := range marks { marks[i] = "?" } - - sep := fmt.Sprintf("%s, %s", Q, Q) qmarks := strings.Join(marks, ", ") - columns := strings.Join(names, sep) + + _, _ = buf.WriteString(qmarks) multi := len(values) / len(names) - if isMulti { - qmarks = strings.Repeat(qmarks+"), (", multi-1) + qmarks + if isMulti && multi > 1 { + for i := 0; i < multi-1; i++ { + _, _ = buf.WriteString("), (") + _, _ = buf.WriteString(qmarks) + } } - query := fmt.Sprintf("INSERT INTO %s%s%s (%s%s%s) VALUES (%s)", Q, mi.table, Q, Q, columns, Q, qmarks) + _ = buf.WriteByte(')') + query := buf.String() d.ins.ReplaceMarks(&query) - if isMulti || !d.ins.HasReturningID(mi, &query) { - res, err := q.Exec(query, values...) + return query +} + +// InsertOrUpdate a row +// If your primary key or unique column conflict will update +// If no will insert +func (d *dbBase) InsertOrUpdate(ctx context.Context, q dbQuerier, mi *models.ModelInfo, ind reflect.Value, a *DB, args ...string) (int64, error) { + + names := make([]string, 0, len(mi.Fields.DBcols)-1) + + values, _, err := d.collectValues(mi, ind, mi.Fields.DBcols, true, true, &names, a.tz) + if err != nil { + return 0, err + } + + query, err := d.InsertOrUpdateSQL(names, &values, mi, a, args...) + + if err != nil { + return 0, err + } + + if !d.ins.HasReturningID(mi, &query) { + res, err := q.ExecContext(ctx, query, values...) if err == nil { - if isMulti { - return res.RowsAffected() + lastInsertId, err := res.LastInsertId() + if err != nil { + DebugLog.Println(ErrLastInsertIdUnavailable, ':', err) + return lastInsertId, ErrLastInsertIdUnavailable + } else { + return lastInsertId, nil } - return res.LastInsertId() } return 0, err } - row := q.QueryRow(query, values...) + + row := q.QueryRowContext(ctx, query, values...) var id int64 - err := row.Scan(&id) + err = row.Scan(&id) + if err != nil && err.Error() == `pq: syntax error at or near "ON"` { + err = fmt.Errorf("postgres version must 9.5 or higher") + } return id, err } -// InsertOrUpdate a row -// If your primary key or unique column conflict will update -// If no will insert -func (d *dbBase) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Value, a *alias, args ...string) (int64, error) { +func (d *dbBase) InsertOrUpdateSQL(names []string, values *[]interface{}, mi *models.ModelInfo, a *DB, args ...string) (string, error) { + args0 := "" - iouStr := "" - argsMap := map[string]string{} - switch a.Driver { + + switch a.driver { case DRMySQL: - iouStr = "ON DUPLICATE KEY UPDATE" case DRPostgres: if len(args) == 0 { - return 0, fmt.Errorf("`%s` use InsertOrUpdate must have a conflict column", a.DriverName) + return "", fmt.Errorf("`%s` use InsertOrUpdate must have a conflict column", a.driverName) } args0 = strings.ToLower(args[0]) - iouStr = fmt.Sprintf("ON CONFLICT (%s) DO UPDATE SET", args0) default: - return 0, fmt.Errorf("`%s` nonsupport InsertOrUpdate in beego", a.DriverName) + return "", fmt.Errorf("`%s` nonsupport InsertOrUpdate in beego", a.driverName) } - //Get on the key-value pairs + argsMap := map[string]string{} + // Get on the key-value pairs for _, v := range args { kv := strings.Split(v, "=") if len(kv) == 2 { @@ -522,85 +619,95 @@ func (d *dbBase) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Value, a } } - isMulti := false - names := make([]string, 0, len(mi.fields.dbcols)-1) - Q := d.ins.TableQuote() - values, _, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, &names, a.TZ) + quote := d.ins.TableQuote() - if err != nil { - return 0, err + buf := buffers.Get() + defer buffers.Put(buf) + + _, _ = buf.WriteString("INSERT INTO ") + _, _ = buf.WriteString(quote) + _, _ = buf.WriteString(mi.Table) + _, _ = buf.WriteString(quote) + _, _ = buf.WriteString(" (") + + for i, name := range names { + if i > 0 { + _, _ = buf.WriteString(", ") + } + _, _ = buf.WriteString(quote) + _, _ = buf.WriteString(name) + _, _ = buf.WriteString(quote) + } + + _, _ = buf.WriteString(") VALUES (") + + for i := 0; i < len(names); i++ { + if i > 0 { + _, _ = buf.WriteString(", ") + } + _, _ = buf.WriteString("?") + } + + _, _ = buf.WriteString(") ") + + switch a.driver { + case DRMySQL: + _, _ = buf.WriteString("ON DUPLICATE KEY UPDATE ") + case DRPostgres: + _, _ = buf.WriteString("ON CONFLICT (") + _, _ = buf.WriteString(args0) + _, _ = buf.WriteString(") DO UPDATE SET ") } - marks := make([]string, len(names)) - updateValues := make([]interface{}, 0) - updates := make([]string, len(names)) var conflitValue interface{} for i, v := range names { + if i > 0 { + _, _ = buf.WriteString(", ") + } // identifier in database may not be case-sensitive, so quote it - v = fmt.Sprintf("%s%s%s", Q, v, Q) - marks[i] = "?" + v = fmt.Sprintf("%s%s%s", quote, v, quote) valueStr := argsMap[strings.ToLower(v)] if v == args0 { - conflitValue = values[i] + conflitValue = (*values)[i] } if valueStr != "" { - switch a.Driver { + switch a.driver { case DRMySQL: - updates[i] = v + "=" + valueStr + _, _ = buf.WriteString(v) + _, _ = buf.WriteString("=") + _, _ = buf.WriteString(valueStr) case DRPostgres: if conflitValue != nil { - //postgres ON CONFLICT DO UPDATE SET can`t use colu=colu+values - updates[i] = fmt.Sprintf("%s=(select %s from %s where %s = ? )", v, valueStr, mi.table, args0) - updateValues = append(updateValues, conflitValue) + // postgres ON CONFLICT DO UPDATE SET can`t use colu=colu+values + _, _ = buf.WriteString(v) + _, _ = buf.WriteString("=(select ") + _, _ = buf.WriteString(valueStr) + _, _ = buf.WriteString(" from ") + _, _ = buf.WriteString(mi.Table) + _, _ = buf.WriteString(" where ") + _, _ = buf.WriteString(args0) + _, _ = buf.WriteString(" = ? )") + *values = append(*values, conflitValue) } else { - return 0, fmt.Errorf("`%s` must be in front of `%s` in your struct", args0, v) + return "", fmt.Errorf("`%s` must be in front of `%s` in your struct", args0, v) } } } else { - updates[i] = v + "=?" - updateValues = append(updateValues, values[i]) + _, _ = buf.WriteString(v) + _, _ = buf.WriteString("=?") + *values = append(*values, (*values)[i]) } } - values = append(values, updateValues...) - - sep := fmt.Sprintf("%s, %s", Q, Q) - qmarks := strings.Join(marks, ", ") - qupdates := strings.Join(updates, ", ") - columns := strings.Join(names, sep) - - multi := len(values) / len(names) - - if isMulti { - qmarks = strings.Repeat(qmarks+"), (", multi-1) + qmarks - } - //conflitValue maybe is a int,can`t use fmt.Sprintf - query := fmt.Sprintf("INSERT INTO %s%s%s (%s%s%s) VALUES (%s) %s "+qupdates, Q, mi.table, Q, Q, columns, Q, qmarks, iouStr) + query := buf.String() d.ins.ReplaceMarks(&query) - if isMulti || !d.ins.HasReturningID(mi, &query) { - res, err := q.Exec(query, values...) - if err == nil { - if isMulti { - return res.RowsAffected() - } - return res.LastInsertId() - } - return 0, err - } - - row := q.QueryRow(query, values...) - var id int64 - err = row.Scan(&id) - if err != nil && err.Error() == `pq: syntax error at or near "ON"` { - err = fmt.Errorf("postgres version must 9.5 or higher") - } - return id, err + return query, nil } -// execute update sql dbQuerier with given struct reflect.Value. -func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) (int64, error) { +// Update execute update sql dbQuerier with given struct reflect.Value. +func (d *dbBase) Update(ctx context.Context, q dbQuerier, mi *models.ModelInfo, ind reflect.Value, tz *time.Location, cols []string) (int64, error) { pkName, pkValue, ok := getExistPk(mi, ind) if !ok { return 0, ErrMissPK @@ -608,10 +715,10 @@ func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time. var setNames []string - // if specify cols length is zero, then commit all columns. + // if specify cols length is zero, then commit All Columns. if len(cols) == 0 { - cols = mi.fields.dbcols - setNames = make([]string, 0, len(mi.fields.dbcols)-1) + cols = mi.Fields.DBcols + setNames = make([]string, 0, len(mi.Fields.DBcols)-1) } else { setNames = make([]string, 0, len(cols)) } @@ -621,41 +728,79 @@ func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time. return 0, err } - var find bool + var findAutoNowAdd, findAutoNow bool var index int for i, col := range setNames { - if mi.fields.GetByColumn(col).autoNowAdd { + if mi.Fields.GetByColumn(col).AutoNowAdd { index = i - find = true + findAutoNowAdd = true + } + if mi.Fields.GetByColumn(col).AutoNow { + findAutoNow = true } } - - if find { + if findAutoNowAdd { setNames = append(setNames[0:index], setNames[index+1:]...) setValues = append(setValues[0:index], setValues[index+1:]...) } + if !findAutoNow { + for col, info := range mi.Fields.Columns { + if info.AutoNow { + setNames = append(setNames, col) + setValues = append(setValues, time.Now()) + } + } + } + setValues = append(setValues, pkValue) + query := d.UpdateSQL(setNames, pkName, mi) + + res, err := q.ExecContext(ctx, query, setValues...) + if err == nil { + return res.RowsAffected() + } + return 0, err +} + +func (d *dbBase) UpdateSQL(setNames []string, pkName string, mi *models.ModelInfo) string { + buf := buffers.Get() + defer buffers.Put(buf) + Q := d.ins.TableQuote() - sep := fmt.Sprintf("%s = ?, %s", Q, Q) - setColumns := strings.Join(setNames, sep) + _, _ = buf.WriteString("UPDATE ") + _, _ = buf.WriteString(Q) + _, _ = buf.WriteString(mi.Table) + _, _ = buf.WriteString(Q) + _, _ = buf.WriteString(" SET ") + + for i, name := range setNames { + if i > 0 { + _, _ = buf.WriteString(", ") + } + _, _ = buf.WriteString(Q) + _, _ = buf.WriteString(name) + _, _ = buf.WriteString(Q) + _, _ = buf.WriteString(" = ?") + } - query := fmt.Sprintf("UPDATE %s%s%s SET %s%s%s = ? WHERE %s%s%s = ?", Q, mi.table, Q, Q, setColumns, Q, Q, pkName, Q) + _, _ = buf.WriteString(" WHERE ") + _, _ = buf.WriteString(Q) + _, _ = buf.WriteString(pkName) + _, _ = buf.WriteString(Q) + _, _ = buf.WriteString(" = ?") + query := buf.String() d.ins.ReplaceMarks(&query) - res, err := q.Exec(query, setValues...) - if err == nil { - return res.RowsAffected() - } - return 0, err + return query } -// execute delete sql dbQuerier with given struct reflect.Value. +// Delete execute delete sql dbQuerier with given struct reflect.Value. // delete index is pk. -func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) (int64, error) { +func (d *dbBase) Delete(ctx context.Context, q dbQuerier, mi *models.ModelInfo, ind reflect.Value, tz *time.Location, cols []string) (int64, error) { var whereCols []string var args []interface{} // if specify cols length > 0, then use it for where condition. @@ -676,29 +821,16 @@ func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time. args = append(args, pkValue) } - Q := d.ins.TableQuote() - - sep := fmt.Sprintf("%s = ? AND %s", Q, Q) - wheres := strings.Join(whereCols, sep) - - query := fmt.Sprintf("DELETE FROM %s%s%s WHERE %s%s%s = ?", Q, mi.table, Q, Q, wheres, Q) + query := d.DeleteSQL(whereCols, mi) - d.ins.ReplaceMarks(&query) - res, err := q.Exec(query, args...) + res, err := q.ExecContext(ctx, query, args...) if err == nil { num, err := res.RowsAffected() if err != nil { return 0, err } if num > 0 { - if mi.fields.pk.auto { - if mi.fields.pk.fieldType&IsPositiveIntegerField > 0 { - ind.FieldByIndex(mi.fields.pk.fieldIndex).SetUint(0) - } else { - ind.FieldByIndex(mi.fields.pk.fieldIndex).SetInt(0) - } - } - err := d.deleteRels(q, mi, args, tz) + err := d.deleteRels(ctx, q, mi, args, tz) if err != nil { return num, err } @@ -708,16 +840,45 @@ func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time. return 0, err } -// update table-related record by querySet. +func (d *dbBase) DeleteSQL(whereCols []string, mi *models.ModelInfo) string { + buf := buffers.Get() + defer buffers.Put(buf) + + Q := d.ins.TableQuote() + + _, _ = buf.WriteString("DELETE FROM ") + _, _ = buf.WriteString(Q) + _, _ = buf.WriteString(mi.Table) + _, _ = buf.WriteString(Q) + _, _ = buf.WriteString(" WHERE ") + + for i, col := range whereCols { + if i > 0 { + _, _ = buf.WriteString(" AND ") + } + _, _ = buf.WriteString(Q) + _, _ = buf.WriteString(col) + _, _ = buf.WriteString(Q) + _, _ = buf.WriteString(" = ?") + } + + query := buf.String() + + d.ins.ReplaceMarks(&query) + + return query +} + +// UpdateBatch update table-related record by querySet. // need querySet not struct reflect.Value to update related records. -func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, params Params, tz *time.Location) (int64, error) { +func (d *dbBase) UpdateBatch(ctx context.Context, q dbQuerier, qs *querySet, mi *models.ModelInfo, cond *Condition, params Params, tz *time.Location) (int64, error) { columns := make([]string, 0, len(params)) values := make([]interface{}, 0, len(params)) for col, val := range params { - if fi, ok := mi.fields.GetByAny(col); !ok || !fi.dbcol { + if fi, ok := mi.Fields.GetByAny(col); !ok || !fi.DBcol { panic(fmt.Errorf("wrong field/column name `%s`", col)) } else { - columns = append(columns, fi.column) + columns = append(columns, fi.Column) values = append(values, val) } } @@ -727,8 +888,10 @@ func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con } tables := newDbTables(mi, d.ins) + var specifyIndexes string if qs != nil { tables.parseRelated(qs.related, qs.relDepth) + specifyIndexes = tables.getIndexSql(mi.Table, qs.useIndex, qs.indexes) } where, args := tables.getCondSQL(cond, false, tz) @@ -737,93 +900,157 @@ func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con join := tables.getJoinSQL() - var query, T string + query := d.UpdateBatchSQL(mi, columns, values, specifyIndexes, join, where) - Q := d.ins.TableQuote() + res, err := q.ExecContext(ctx, query, values...) + if err == nil { + return res.RowsAffected() + } + return 0, err +} + +func (d *dbBase) UpdateBatchSQL(mi *models.ModelInfo, cols []string, values []interface{}, specifyIndexes, join, where string) string { + quote := d.ins.TableQuote() + + buf := buffers.Get() + defer buffers.Put(buf) + + _, _ = buf.WriteString("UPDATE ") + _, _ = buf.WriteString(quote) + _, _ = buf.WriteString(mi.Table) + _, _ = buf.WriteString(quote) + + if d.ins.SupportUpdateJoin() { + _, _ = buf.WriteString(" T0 ") + _, _ = buf.WriteString(specifyIndexes) + _, _ = buf.WriteString(join) + + d.buildSetSQL(buf, cols, values) + + _, _ = buf.WriteString(" ") + _, _ = buf.WriteString(where) + } else { + _, _ = buf.WriteString(" ") + + d.buildSetSQL(buf, cols, values) + + _, _ = buf.WriteString(" WHERE ") + _, _ = buf.WriteString(quote) + _, _ = buf.WriteString(mi.Fields.Pk.Column) + _, _ = buf.WriteString(quote) + _, _ = buf.WriteString(" IN ( ") + _, _ = buf.WriteString("SELECT T0.") + _, _ = buf.WriteString(quote) + _, _ = buf.WriteString(mi.Fields.Pk.Column) + _, _ = buf.WriteString(quote) + _, _ = buf.WriteString(" FROM ") + _, _ = buf.WriteString(quote) + _, _ = buf.WriteString(mi.Table) + _, _ = buf.WriteString(quote) + _, _ = buf.WriteString(" T0 ") + _, _ = buf.WriteString(specifyIndexes) + _, _ = buf.WriteString(join) + _, _ = buf.WriteString(where) + _, _ = buf.WriteString(" )") + } + + query := buf.String() + + d.ins.ReplaceMarks(&query) + + return query +} + +func (d *dbBase) buildSetSQL(buf buffers.Buffer, cols []string, values []interface{}) { + + var owner string + + quote := d.ins.TableQuote() if d.ins.SupportUpdateJoin() { - T = "T0." + owner = "T0." } - cols := make([]string, 0, len(columns)) + _, _ = buf.WriteString("SET ") - for i, v := range columns { - col := fmt.Sprintf("%s%s%s%s", T, Q, v, Q) + for i, v := range cols { + if i > 0 { + _, _ = buf.WriteString(", ") + } + _, _ = buf.WriteString(owner) + _, _ = buf.WriteString(quote) + _, _ = buf.WriteString(v) + _, _ = buf.WriteString(quote) + _, _ = buf.WriteString(" = ") if c, ok := values[i].(colValue); ok { + _, _ = buf.WriteString(owner) + _, _ = buf.WriteString(quote) + _, _ = buf.WriteString(v) + _, _ = buf.WriteString(quote) switch c.opt { case ColAdd: - cols = append(cols, col+" = "+col+" + ?") + _, _ = buf.WriteString(" + ?") case ColMinus: - cols = append(cols, col+" = "+col+" - ?") + _, _ = buf.WriteString(" - ?") case ColMultiply: - cols = append(cols, col+" = "+col+" * ?") + _, _ = buf.WriteString(" * ?") case ColExcept: - cols = append(cols, col+" = "+col+" / ?") + _, _ = buf.WriteString(" / ?") + case ColBitAnd: + _, _ = buf.WriteString(" & ?") + case ColBitRShift: + _, _ = buf.WriteString(" >> ?") + case ColBitLShift: + _, _ = buf.WriteString(" << ?") + case ColBitXOR: + _, _ = buf.WriteString(" ^ ?") + case ColBitOr: + _, _ = buf.WriteString(" | ?") } values[i] = c.value } else { - cols = append(cols, col+" = ?") + _, _ = buf.WriteString("?") } } - - sets := strings.Join(cols, ", ") + " " - - if d.ins.SupportUpdateJoin() { - query = fmt.Sprintf("UPDATE %s%s%s T0 %sSET %s%s", Q, mi.table, Q, join, sets, where) - } else { - supQuery := fmt.Sprintf("SELECT T0.%s%s%s FROM %s%s%s T0 %s%s", Q, mi.fields.pk.column, Q, Q, mi.table, Q, join, where) - query = fmt.Sprintf("UPDATE %s%s%s SET %sWHERE %s%s%s IN ( %s )", Q, mi.table, Q, sets, Q, mi.fields.pk.column, Q, supQuery) - } - - d.ins.ReplaceMarks(&query) - var err error - var res sql.Result - if qs != nil && qs.forContext { - res, err = q.ExecContext(qs.ctx, query, values...) - } else { - res, err = q.Exec(query, values...) - } - if err == nil { - return res.RowsAffected() - } - return 0, err } // delete related records. // do UpdateBanch or DeleteBanch by condition of tables' relationship. -func (d *dbBase) deleteRels(q dbQuerier, mi *modelInfo, args []interface{}, tz *time.Location) error { - for _, fi := range mi.fields.fieldsReverse { - fi = fi.reverseFieldInfo - switch fi.onDelete { - case odCascade: - cond := NewCondition().And(fmt.Sprintf("%s__in", fi.name), args...) - _, err := d.DeleteBatch(q, nil, fi.mi, cond, tz) +func (d *dbBase) deleteRels(ctx context.Context, q dbQuerier, mi *models.ModelInfo, args []interface{}, tz *time.Location) error { + for _, fi := range mi.Fields.FieldsReverse { + fi = fi.ReverseFieldInfo + switch fi.OnDelete { + case models.OdCascade: + cond := NewCondition().And(fmt.Sprintf("%s__in", fi.Name), args...) + _, err := d.DeleteBatch(ctx, q, nil, fi.Mi, cond, tz) if err != nil { return err } - case odSetDefault, odSetNULL: - cond := NewCondition().And(fmt.Sprintf("%s__in", fi.name), args...) - params := Params{fi.column: nil} - if fi.onDelete == odSetDefault { - params[fi.column] = fi.initial.String() + case models.OdSetDefault, models.OdSetNULL: + cond := NewCondition().And(fmt.Sprintf("%s__in", fi.Name), args...) + params := Params{fi.Column: nil} + if fi.OnDelete == models.OdSetDefault { + params[fi.Column] = fi.Initial.String() } - _, err := d.UpdateBatch(q, nil, fi.mi, cond, params, tz) + _, err := d.UpdateBatch(ctx, q, nil, fi.Mi, cond, params, tz) if err != nil { return err } - case odDoNothing: + case models.OdDoNothing: } } return nil } -// delete table-related records. -func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, tz *time.Location) (int64, error) { +// DeleteBatch delete table-related records. +func (d *dbBase) DeleteBatch(ctx context.Context, q dbQuerier, qs *querySet, mi *models.ModelInfo, cond *Condition, tz *time.Location) (int64, error) { tables := newDbTables(mi, d.ins) tables.skipEnd = true + var specifyIndexes string if qs != nil { tables.parseRelated(qs.related, qs.relDepth) + specifyIndexes = tables.getIndexSql(mi.Table, qs.useIndex, qs.indexes) } if cond == nil || cond.IsEmpty() { @@ -835,13 +1062,13 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con where, args := tables.getCondSQL(cond, false, tz) join := tables.getJoinSQL() - cols := fmt.Sprintf("T0.%s%s%s", Q, mi.fields.pk.column, Q) - query := fmt.Sprintf("SELECT %s FROM %s%s%s T0 %s%s", cols, Q, mi.table, Q, join, where) + cols := fmt.Sprintf("T0.%s%s%s", Q, mi.Fields.Pk.Column, Q) + query := fmt.Sprintf("SELECT %s FROM %s%s%s T0 %s%s%s", cols, Q, mi.Table, Q, specifyIndexes, join, where) d.ins.ReplaceMarks(&query) var rs *sql.Rows - r, err := q.Query(query, args...) + r, err := q.QueryContext(ctx, query, args...) if err != nil { return 0, err } @@ -855,7 +1082,7 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con if err := rs.Scan(&ref); err != nil { return 0, err } - pkValue, err := d.convertValueFromDB(mi.fields.pk, reflect.ValueOf(ref).Interface(), tz) + pkValue, err := d.convertValueFromDB(mi.Fields.Pk, reflect.ValueOf(ref).Interface(), tz) if err != nil { return 0, err } @@ -863,6 +1090,10 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con cnt++ } + if err = rs.Err(); err != nil { + return 0, err + } + if cnt == 0 { return 0, nil } @@ -872,22 +1103,17 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con marks[i] = "?" } sqlIn := fmt.Sprintf("IN (%s)", strings.Join(marks, ", ")) - query = fmt.Sprintf("DELETE FROM %s%s%s WHERE %s%s%s %s", Q, mi.table, Q, Q, mi.fields.pk.column, Q, sqlIn) + query = fmt.Sprintf("DELETE FROM %s%s%s WHERE %s%s%s %s", Q, mi.Table, Q, Q, mi.Fields.Pk.Column, Q, sqlIn) d.ins.ReplaceMarks(&query) - var res sql.Result - if qs != nil && qs.forContext { - res, err = q.ExecContext(qs.ctx, query, args...) - } else { - res, err = q.Exec(query, args...) - } + res, err := q.ExecContext(ctx, query, args...) if err == nil { num, err := res.RowsAffected() if err != nil { return 0, err } if num > 0 { - err := d.deleteRels(q, mi, args, tz) + err := d.deleteRels(ctx, q, mi, args, tz) if err != nil { return num, err } @@ -897,15 +1123,15 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con return 0, err } -// read related records. -func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, container interface{}, tz *time.Location, cols []string) (int64, error) { - +// ReadBatch read related records. +func (d *dbBase) ReadBatch(ctx context.Context, q dbQuerier, qs querySet, mi *models.ModelInfo, cond *Condition, container interface{}, tz *time.Location, cols []string) (int64, error) { val := reflect.ValueOf(container) ind := reflect.Indirect(val) - errTyp := true + unregister := true one := true isPtr := true + name := "" if val.Kind() == reflect.Ptr { fn := "" @@ -914,30 +1140,23 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi typ := ind.Type().Elem() switch typ.Kind() { case reflect.Ptr: - fn = getFullName(typ.Elem()) + fn = models.GetFullName(typ.Elem()) case reflect.Struct: isPtr = false - fn = getFullName(typ) + fn = models.GetFullName(typ) + name = models.GetTableName(reflect.New(typ)) } } else { - fn = getFullName(ind.Type()) + fn = models.GetFullName(ind.Type()) + name = models.GetTableName(ind) } - errTyp = fn != mi.fullName + unregister = fn != mi.FullName } - if errTyp { - if one { - panic(fmt.Errorf("wrong object type `%s` for rows scan, need *%s", val.Type(), mi.fullName)) - } else { - panic(fmt.Errorf("wrong object type `%s` for rows scan, need *[]*%s or *[]%s", val.Type(), mi.fullName, mi.fullName)) - } + if unregister { + RegisterModel(container) } - rlimit := qs.limit - offset := qs.offset - - Q := d.ins.TableQuote() - var tCols []string if len(cols) > 0 { hasRel := len(qs.related) > 0 || qs.relDepth > 0 @@ -947,73 +1166,53 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi maps = make(map[string]bool) } for _, col := range cols { - if fi, ok := mi.fields.GetByAny(col); ok { - tCols = append(tCols, fi.column) + if fi, ok := mi.Fields.GetByAny(col); ok { + tCols = append(tCols, fi.Column) if hasRel { - maps[fi.column] = true + maps[fi.Column] = true } } else { return 0, fmt.Errorf("wrong field/column name `%s`", col) } } if hasRel { - for _, fi := range mi.fields.fieldsDB { - if fi.fieldType&IsRelField > 0 { - if !maps[fi.column] { - tCols = append(tCols, fi.column) + for _, fi := range mi.Fields.FieldsDB { + if fi.FieldType&IsRelField > 0 { + if !maps[fi.Column] { + tCols = append(tCols, fi.Column) } } } } } else { - tCols = mi.fields.dbcols + tCols = mi.Fields.DBcols } - colsNum := len(tCols) - sep := fmt.Sprintf("%s, T0.%s", Q, Q) - sels := fmt.Sprintf("T0.%s%s%s", Q, strings.Join(tCols, sep), Q) - tables := newDbTables(mi, d.ins) tables.parseRelated(qs.related, qs.relDepth) - where, args := tables.getCondSQL(cond, false, tz) - groupBy := tables.getGroupSQL(qs.groups) - orderBy := tables.getOrderSQL(qs.orders) - limit := tables.getLimitSQL(mi, offset, rlimit) - join := tables.getJoinSQL() + colsNum := len(tCols) for _, tbl := range tables.tables { if tbl.sel { - colsNum += len(tbl.mi.fields.dbcols) - sep := fmt.Sprintf("%s, %s.%s", Q, tbl.index, Q) - sels += fmt.Sprintf(", %s.%s%s%s", tbl.index, Q, strings.Join(tbl.mi.fields.dbcols, sep), Q) + colsNum += len(tbl.mi.Fields.DBcols) } } - sqlSelect := "SELECT" - if qs.distinct { - sqlSelect += " DISTINCT" - } - query := fmt.Sprintf("%s %s FROM %s%s%s T0 %s%s%s%s%s", sqlSelect, sels, Q, mi.table, Q, join, where, groupBy, orderBy, limit) + query, args := d.readBatchSQL(tables, tCols, cond, qs, mi, tz) - if qs.forupdate { - query += " FOR UPDATE" + rs, err := q.QueryContext(ctx, query, args...) + if err != nil { + return 0, err } - d.ins.ReplaceMarks(&query) + defer rs.Close() - var rs *sql.Rows - var err error - if qs != nil && qs.forContext { - rs, err = q.QueryContext(qs.ctx, query, args...) - if err != nil { - return 0, err - } - } else { - rs, err = q.Query(query, args...) - if err != nil { - return 0, err - } + slice := ind + if unregister { + mi, _ = models.DefaultModelCache.Get(name) + tCols = mi.Fields.DBcols + colsNum = len(tCols) } refs := make([]interface{}, colsNum) @@ -1021,11 +1220,6 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi var ref interface{} refs[i] = &ref } - - defer rs.Close() - - slice := ind - var cnt int64 for rs.Next() { if one && cnt == 0 || !one { @@ -1033,11 +1227,11 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi return 0, err } - elm := reflect.New(mi.addrField.Elem().Type()) + elm := reflect.New(mi.AddrField.Elem().Type()) mind := reflect.Indirect(elm) cacheV := make(map[string]*reflect.Value) - cacheM := make(map[string]*modelInfo) + cacheM := make(map[string]*models.ModelInfo) trefs := refs d.setColsValues(mi, &mind, tCols, refs[:len(tCols)], tz) @@ -1056,18 +1250,18 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi last = *val mmi = cacheM[names] } else { - fi := mmi.fields.GetByName(name) + fi := mmi.Fields.GetByName(name) lastm := mmi - mmi = fi.relModelInfo + mmi = fi.RelModelInfo field := last if last.Kind() != reflect.Invalid { - field = reflect.Indirect(last.FieldByIndex(fi.fieldIndex)) + field = reflect.Indirect(last.FieldByIndex(fi.FieldIndex)) if field.IsValid() { - d.setColsValues(mmi, &field, mmi.fields.dbcols, trefs[:len(mmi.fields.dbcols)], tz) - for _, fi := range mmi.fields.fieldsReverse { - if fi.inModel && fi.reverseFieldInfo.mi == lastm { - if fi.reverseFieldInfo != nil { - f := field.FieldByIndex(fi.fieldIndex) + d.setColsValues(mmi, &field, mmi.Fields.DBcols, trefs[:len(mmi.Fields.DBcols)], tz) + for _, fi := range mmi.Fields.FieldsReverse { + if fi.InModel && fi.ReverseFieldInfo.Mi == lastm { + if fi.ReverseFieldInfo != nil { + f := field.FieldByIndex(fi.FieldIndex) if f.Kind() == reflect.Ptr { f.Set(last.Addr()) } @@ -1081,7 +1275,7 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi cacheM[names] = mmi } } - trefs = trefs[len(mmi.fields.dbcols):] + trefs = trefs[len(mmi.Fields.DBcols):] } } @@ -1089,7 +1283,7 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi ind.Set(mind) } else { if cnt == 0 { - // you can use a empty & caped container list + // you can use an empty & caped container list // orm will not replace it if ind.Len() != 0 { // if container is not empty @@ -1108,12 +1302,16 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi cnt++ } + if err = rs.Err(); err != nil { + return 0, err + } + if !one { if cnt > 0 { ind.Set(slice) } else { // when a result is empty and container is nil - // to set a empty container + // to Set an empty container if ind.IsNil() { ind.Set(reflect.MakeSlice(ind.Type(), 0, 0)) } @@ -1123,38 +1321,134 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi return cnt, nil } -// excute count sql and return count result int64. -func (d *dbBase) Count(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, tz *time.Location) (cnt int64, err error) { - tables := newDbTables(mi, d.ins) - tables.parseRelated(qs.related, qs.relDepth) +func (d *dbBase) readBatchSQL(tables *dbTables, tCols []string, cond *Condition, qs querySet, mi *models.ModelInfo, tz *time.Location) (string, []interface{}) { + cols := d.preProcCols(tCols) // pre process columns + + buf := buffers.Get() + defer buffers.Put(buf) + + args := d.readSQL(buf, tables, cols, cond, qs, mi, tz) + + query := buf.String() + + d.ins.ReplaceMarks(&query) + + return query, args +} + +func (d *dbBase) preProcCols(cols []string) []string { + res := make([]string, len(cols)) + + quote := d.ins.TableQuote() + for i, col := range cols { + res[i] = fmt.Sprintf("T0.%s%s%s", quote, col, quote) + } + + return res +} + +// readSQL generate a select sql string and return args +// ReadBatch and ReadValues methods will reuse this method. +func (d *dbBase) readSQL(buf buffers.Buffer, tables *dbTables, tCols []string, cond *Condition, qs querySet, mi *models.ModelInfo, tz *time.Location) []interface{} { + + quote := d.ins.TableQuote() where, args := tables.getCondSQL(cond, false, tz) groupBy := tables.getGroupSQL(qs.groups) - tables.getOrderSQL(qs.orders) + orderBy := tables.getOrderSQL(qs.orders) + limit := tables.getLimitSQL(mi, qs.offset, qs.limit) join := tables.getJoinSQL() + specifyIndexes := tables.getIndexSql(mi.Table, qs.useIndex, qs.indexes) - Q := d.ins.TableQuote() - - query := fmt.Sprintf("SELECT COUNT(*) FROM %s%s%s T0 %s%s%s", Q, mi.table, Q, join, where, groupBy) + _, _ = buf.WriteString("SELECT ") - if groupBy != "" { - query = fmt.Sprintf("SELECT COUNT(*) FROM (%s) AS T", query) + if qs.distinct { + _, _ = buf.WriteString("DISTINCT ") } - d.ins.ReplaceMarks(&query) + if qs.aggregate == "" { + for i, tCol := range tCols { + if i > 0 { + _, _ = buf.WriteString(", ") + } + _, _ = buf.WriteString(tCol) + } - var row *sql.Row - if qs != nil && qs.forContext { - row = q.QueryRowContext(qs.ctx, query, args...) + for _, tbl := range tables.tables { + if tbl.sel { + _, _ = buf.WriteString(", ") + for i, DBcol := range tbl.mi.Fields.DBcols { + if i > 0 { + _, _ = buf.WriteString(", ") + } + _, _ = buf.WriteString(tbl.index) + _, _ = buf.WriteString(".") + _, _ = buf.WriteString(quote) + _, _ = buf.WriteString(DBcol) + _, _ = buf.WriteString(quote) + } + } + } } else { - row = q.QueryRow(query, args...) + _, _ = buf.WriteString(qs.aggregate) } + + _, _ = buf.WriteString(" FROM ") + _, _ = buf.WriteString(quote) + _, _ = buf.WriteString(mi.Table) + _, _ = buf.WriteString(quote) + _, _ = buf.WriteString(" T0 ") + _, _ = buf.WriteString(specifyIndexes) + _, _ = buf.WriteString(join) + _, _ = buf.WriteString(where) + _, _ = buf.WriteString(groupBy) + _, _ = buf.WriteString(orderBy) + _, _ = buf.WriteString(limit) + + if qs.forUpdate { + _, _ = buf.WriteString(" FOR UPDATE") + } + + return args +} + +// Count excute count sql and return count result int64. +func (d *dbBase) Count(ctx context.Context, q dbQuerier, qs querySet, mi *models.ModelInfo, cond *Condition, tz *time.Location) (cnt int64, err error) { + + query, args := d.countSQL(qs, mi, cond, tz) + + row := q.QueryRowContext(ctx, query, args...) err = row.Scan(&cnt) return } -// generate sql with replacing operator string placeholders and replaced values. -func (d *dbBase) GenerateOperatorSQL(mi *modelInfo, fi *fieldInfo, operator string, args []interface{}, tz *time.Location) (string, []interface{}) { +func (d *dbBase) countSQL(qs querySet, mi *models.ModelInfo, cond *Condition, tz *time.Location) (string, []interface{}) { + tables := newDbTables(mi, d.ins) + tables.parseRelated(qs.related, qs.relDepth) + + buf := buffers.Get() + defer buffers.Put(buf) + + if len(qs.groups) > 0 { + _, _ = buf.WriteString("SELECT COUNT(*) FROM (") + } + + qs.aggregate = "COUNT(*)" + args := d.readSQL(buf, tables, nil, cond, qs, mi, tz) + + if len(qs.groups) > 0 { + _, _ = buf.WriteString(") AS T") + } + + query := buf.String() + + d.ins.ReplaceMarks(&query) + + return query, args +} + +// GenerateOperatorSQL generate sql with replacing operator string placeholders and replaced values. +func (d *dbBase) GenerateOperatorSQL(mi *models.ModelInfo, fi *models.FieldInfo, operator string, args []interface{}, tz *time.Location) (string, []interface{}) { var sql string params := getFlatParams(fi, args, tz) @@ -1186,7 +1480,7 @@ func (d *dbBase) GenerateOperatorSQL(mi *modelInfo, fi *fieldInfo, operator stri params[0] = "IS NULL" } case "iexact", "contains", "icontains", "startswith", "endswith", "istartswith", "iendswith": - param := strings.Replace(ToStr(arg), `%`, `\%`, -1) + param := strings.Replace(utils.ToStr(arg), `%`, `\%`, -1) switch operator { case "iexact": case "contains", "icontains": @@ -1213,19 +1507,19 @@ func (d *dbBase) GenerateOperatorSQL(mi *modelInfo, fi *fieldInfo, operator stri return sql, params } -// gernerate sql string with inner function, such as UPPER(text). -func (d *dbBase) GenerateOperatorLeftCol(*fieldInfo, string, *string) { +// GenerateOperatorLeftCol gernerate sql string with inner function, such as UPPER(text). +func (d *dbBase) GenerateOperatorLeftCol(*models.FieldInfo, string, *string) { // default not use } -// set values to struct column. -func (d *dbBase) setColsValues(mi *modelInfo, ind *reflect.Value, cols []string, values []interface{}, tz *time.Location) { +// Set values to struct column. +func (d *dbBase) setColsValues(mi *models.ModelInfo, ind *reflect.Value, cols []string, values []interface{}, tz *time.Location) { for i, column := range cols { val := reflect.Indirect(reflect.ValueOf(values[i])).Interface() - fi := mi.fields.GetByColumn(column) + fi := mi.Fields.GetByColumn(column) - field := ind.FieldByIndex(fi.fieldIndex) + field := ind.FieldByIndex(fi.FieldIndex) value, err := d.convertValueFromDB(fi, val, tz) if err != nil { @@ -1241,7 +1535,7 @@ func (d *dbBase) setColsValues(mi *modelInfo, ind *reflect.Value, cols []string, } // convert value from database result to value following in field type. -func (d *dbBase) convertValueFromDB(fi *fieldInfo, val interface{}, tz *time.Location) (interface{}, error) { +func (d *dbBase) convertValueFromDB(fi *models.FieldInfo, val interface{}, tz *time.Location) (interface{}, error) { if val == nil { return nil, nil } @@ -1249,17 +1543,17 @@ func (d *dbBase) convertValueFromDB(fi *fieldInfo, val interface{}, tz *time.Loc var value interface{} var tErr error - var str *StrTo + var str *utils.StrTo switch v := val.(type) { case []byte: - s := StrTo(string(v)) + s := utils.StrTo(string(v)) str = &s case string: - s := StrTo(v) + s := utils.StrTo(v) str = &s } - fieldType := fi.fieldType + fieldType := fi.FieldType setValue: switch { @@ -1270,7 +1564,7 @@ setValue: b := v == 1 value = b default: - s := StrTo(ToStr(v)) + s := utils.StrTo(utils.ToStr(v)) str = &s } } @@ -1284,7 +1578,7 @@ setValue: } case fieldType == TypeVarCharField || fieldType == TypeCharField || fieldType == TypeTextField || fieldType == TypeJSONField || fieldType == TypeJsonbField: if str == nil { - value = ToStr(val) + value = utils.ToStr(val) } else { value = str.String() } @@ -1295,7 +1589,7 @@ setValue: d.ins.TimeFromDB(&t, tz) value = t default: - s := StrTo(ToStr(t)) + s := utils.StrTo(utils.ToStr(t)) str = &s } } @@ -1305,19 +1599,26 @@ setValue: t time.Time err error ) - if len(s) >= 19 { + + if fi.TimePrecision != nil && len(s) >= (20+*fi.TimePrecision) { + layout := utils.FormatDateTime + "." + for i := 0; i < *fi.TimePrecision; i++ { + layout += "0" + } + t, err = time.ParseInLocation(layout, s[:20+*fi.TimePrecision], tz) + } else if len(s) >= 19 { s = s[:19] - t, err = time.ParseInLocation(formatDateTime, s, tz) + t, err = time.ParseInLocation(utils.FormatDateTime, s, tz) } else if len(s) >= 10 { if len(s) > 10 { s = s[:10] } - t, err = time.ParseInLocation(formatDate, s, tz) + t, err = time.ParseInLocation(utils.FormatDate, s, tz) } else if len(s) >= 8 { if len(s) > 8 { s = s[:8] } - t, err = time.ParseInLocation(formatTime, s, tz) + t, err = time.ParseInLocation(utils.FormatTime, s, tz) } t = t.In(DefaultTimeLoc) @@ -1329,7 +1630,7 @@ setValue: } case fieldType&IsIntegerField > 0: if str == nil { - s := StrTo(ToStr(val)) + s := utils.StrTo(utils.ToStr(val)) str = &s } if str != nil { @@ -1370,7 +1671,7 @@ setValue: case float64: value = v default: - s := StrTo(ToStr(v)) + s := utils.StrTo(utils.ToStr(v)) str = &s } } @@ -1383,26 +1684,24 @@ setValue: value = v } case fieldType&IsRelField > 0: - fi = fi.relModelInfo.fields.pk - fieldType = fi.fieldType + fi = fi.RelModelInfo.Fields.Pk + fieldType = fi.FieldType goto setValue } end: if tErr != nil { - err := fmt.Errorf("convert to `%s` failed, field: %s err: %s", fi.addrValue.Type(), fi.fullName, tErr) + err := fmt.Errorf("convert to `%s` failed, field: %s err: %s", fi.AddrValue.Type(), fi.FullName, tErr) return nil, err } return value, nil - } -// set one value to struct column field. -func (d *dbBase) setFieldValue(fi *fieldInfo, value interface{}, field reflect.Value) (interface{}, error) { - - fieldType := fi.fieldType - isNative := !fi.isFielder +// Set one value to struct column field. +func (d *dbBase) setFieldValue(fi *models.FieldInfo, value interface{}, field reflect.Value) (interface{}, error) { + fieldType := fi.FieldType + isNative := !fi.IsFielder setValue: switch { @@ -1569,20 +1868,20 @@ setValue: } case fieldType&IsRelField > 0: if value != nil { - fieldType = fi.relModelInfo.fields.pk.fieldType - mf := reflect.New(fi.relModelInfo.addrField.Elem().Type()) + fieldType = fi.RelModelInfo.Fields.Pk.FieldType + mf := reflect.New(fi.RelModelInfo.AddrField.Elem().Type()) field.Set(mf) - f := mf.Elem().FieldByIndex(fi.relModelInfo.fields.pk.fieldIndex) + f := mf.Elem().FieldByIndex(fi.RelModelInfo.Fields.Pk.FieldIndex) field = f goto setValue } } if !isNative { - fd := field.Addr().Interface().(Fielder) + fd := field.Addr().Interface().(models.Fielder) err := fd.SetRaw(value) if err != nil { - err = fmt.Errorf("converted value `%v` set to Fielder `%s` failed, err: %s", value, fi.fullName, err) + err = fmt.Errorf("converted value `%v` Set to Fielder `%s` failed, err: %s", value, fi.FullName, err) return nil, err } } @@ -1590,9 +1889,8 @@ setValue: return value, nil } -// query sql, read values , save to *[]ParamList. -func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, exprs []string, container interface{}, tz *time.Location) (int64, error) { - +// ReadValues query sql, read values , save to *[]ParamList. +func (d *dbBase) ReadValues(ctx context.Context, q dbQuerier, qs querySet, mi *models.ModelInfo, cond *Condition, exprs []string, container interface{}, tz *time.Location) (int64, error) { var ( maps []Params lists []ParamsList @@ -1627,7 +1925,7 @@ func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Cond var ( cols []string - infos []*fieldInfo + infos []*models.FieldInfo ) hasExprs := len(exprs) > 0 @@ -1636,41 +1934,27 @@ func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Cond if hasExprs { cols = make([]string, 0, len(exprs)) - infos = make([]*fieldInfo, 0, len(exprs)) + infos = make([]*models.FieldInfo, 0, len(exprs)) for _, ex := range exprs { index, name, fi, suc := tables.parseExprs(mi, strings.Split(ex, ExprSep)) if !suc { panic(fmt.Errorf("unknown field/column name `%s`", ex)) } - cols = append(cols, fmt.Sprintf("%s.%s%s%s %s%s%s", index, Q, fi.column, Q, Q, name, Q)) + cols = append(cols, fmt.Sprintf("%s.%s%s%s %s%s%s", index, Q, fi.Column, Q, Q, name, Q)) infos = append(infos, fi) } } else { - cols = make([]string, 0, len(mi.fields.dbcols)) - infos = make([]*fieldInfo, 0, len(exprs)) - for _, fi := range mi.fields.fieldsDB { - cols = append(cols, fmt.Sprintf("T0.%s%s%s %s%s%s", Q, fi.column, Q, Q, fi.name, Q)) + cols = make([]string, 0, len(mi.Fields.DBcols)) + infos = make([]*models.FieldInfo, 0, len(exprs)) + for _, fi := range mi.Fields.FieldsDB { + cols = append(cols, fmt.Sprintf("T0.%s%s%s %s%s%s", Q, fi.Column, Q, Q, fi.Name, Q)) infos = append(infos, fi) } } - where, args := tables.getCondSQL(cond, false, tz) - groupBy := tables.getGroupSQL(qs.groups) - orderBy := tables.getOrderSQL(qs.orders) - limit := tables.getLimitSQL(mi, qs.offset, qs.limit) - join := tables.getJoinSQL() - - sels := strings.Join(cols, ", ") + query, args := d.readValuesSQL(tables, cols, qs, mi, cond, tz) - sqlSelect := "SELECT" - if qs.distinct { - sqlSelect += " DISTINCT" - } - query := fmt.Sprintf("%s %s FROM %s%s%s T0 %s%s%s%s%s", sqlSelect, sels, Q, mi.table, Q, join, where, groupBy, orderBy, limit) - - d.ins.ReplaceMarks(&query) - - rs, err := q.Query(query, args...) + rs, err := q.QueryContext(ctx, query, args...) if err != nil { return 0, err } @@ -1748,6 +2032,10 @@ func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Cond cnt++ } + if err = rs.Err(); err != nil { + return 0, err + } + switch v := container.(type) { case *[]Params: *v = maps @@ -1760,11 +2048,20 @@ func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Cond return cnt, nil } -func (d *dbBase) RowsTo(dbQuerier, *querySet, *modelInfo, *Condition, interface{}, string, string, *time.Location) (int64, error) { - return 0, nil +func (d *dbBase) readValuesSQL(tables *dbTables, cols []string, qs querySet, mi *models.ModelInfo, cond *Condition, tz *time.Location) (string, []interface{}) { + buf := buffers.Get() + defer buffers.Put(buf) + + args := d.readSQL(buf, tables, cols, cond, qs, mi, tz) + + query := buf.String() + + d.ins.ReplaceMarks(&query) + + return query, args } -// flag of update joined record. +// SupportUpdateJoin flag of update joined record. func (d *dbBase) SupportUpdateJoin() bool { return true } @@ -1773,42 +2070,42 @@ func (d *dbBase) MaxLimit() uint64 { return 18446744073709551615 } -// return quote. +// TableQuote return quote. func (d *dbBase) TableQuote() string { return "`" } -// replace value placeholder in parametered sql string. +// ReplaceMarks replace value placeholder in parametered sql string. func (d *dbBase) ReplaceMarks(query *string) { // default use `?` as mark, do nothing } // flag of RETURNING sql. -func (d *dbBase) HasReturningID(*modelInfo, *string) bool { +func (d *dbBase) HasReturningID(*models.ModelInfo, *string) bool { return false } // sync auto key -func (d *dbBase) setval(db dbQuerier, mi *modelInfo, autoFields []string) error { +func (d *dbBase) setval(ctx context.Context, db dbQuerier, mi *models.ModelInfo, autoFields []string) error { return nil } -// convert time from db. +// TimeFromDB convert time from db. func (d *dbBase) TimeFromDB(t *time.Time, tz *time.Location) { *t = t.In(tz) } -// convert time to db. +// TimeToDB convert time to db. func (d *dbBase) TimeToDB(t *time.Time, tz *time.Location) { *t = t.In(tz) } -// get database types. +// DbTypes Get database types. func (d *dbBase) DbTypes() map[string]string { return nil } -// gt all tables. +// GetTables gt All tables. func (d *dbBase) GetTables(db dbQuerier) (map[string]bool, error) { tables := make(map[string]bool) query := d.ins.ShowTablesQuery() @@ -1830,14 +2127,14 @@ func (d *dbBase) GetTables(db dbQuerier) (map[string]bool, error) { } } - return tables, nil + return tables, rows.Err() } -// get all cloumns in table. -func (d *dbBase) GetColumns(db dbQuerier, table string) (map[string][3]string, error) { +// GetColumns Get All cloumns in table. +func (d *dbBase) GetColumns(ctx context.Context, db dbQuerier, table string) (map[string][3]string, error) { columns := make(map[string][3]string) query := d.ins.ShowColumnsQuery(table) - rows, err := db.Query(query) + rows, err := db.QueryContext(ctx, query) if err != nil { return columns, err } @@ -1857,7 +2154,7 @@ func (d *dbBase) GetColumns(db dbQuerier, table string) (map[string][3]string, e columns[name] = [3]string{name, typ, null} } - return columns, nil + return columns, rows.Err() } // not implement. @@ -1876,6 +2173,32 @@ func (d *dbBase) ShowColumnsQuery(table string) string { } // not implement. -func (d *dbBase) IndexExists(dbQuerier, string, string) bool { +func (d *dbBase) IndexExists(context.Context, dbQuerier, string, string) bool { panic(ErrNotImplement) } + +// GenerateSpecifyIndex return a specifying index clause +func (d *dbBase) GenerateSpecifyIndex(tableName string, useIndex int, indexes []string) string { + var s []string + Q := d.TableQuote() + for _, index := range indexes { + tmp := fmt.Sprintf(`%s%s%s`, Q, index, Q) + s = append(s, tmp) + } + + var useWay string + + switch useIndex { + case hints.KeyUseIndex: + useWay = `USE` + case hints.KeyForceIndex: + useWay = `FORCE` + case hints.KeyIgnoreIndex: + useWay = `IGNORE` + default: + DebugLog.Println("[WARN] Not a valid specifying action, so that action is ignored") + return `` + } + + return fmt.Sprintf(` %s INDEX(%s) `, useWay, strings.Join(s, `,`)) +} diff --git a/client/orm/db_alias.go b/client/orm/db_alias.go new file mode 100644 index 0000000000..18dd55e9a4 --- /dev/null +++ b/client/orm/db_alias.go @@ -0,0 +1,678 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "context" + "database/sql" + "fmt" + "sync" + "time" + + lru "github.com/hashicorp/golang-lru" +) + +// DriverType database driver constant int. +type DriverType int + +// Enum the Database driver +const ( + _ DriverType = iota // int enum type + DRMySQL // mysql + DRSqlite // sqlite + DROracle // oracle + DRPostgres // pgsql + DRTiDB // TiDB +) + +// database driver string. +type driver string + +// Get type constant int of current driver.. +func (d driver) Type() DriverType { + a, _ := dataBaseCache.get(string(d)) + return a.driver +} + +// Get name of current driver +func (d driver) Name() string { + return string(d) +} + +// check driver iis implemented Driver interface or not. +var _ Driver = new(driver) + +var ( + dataBaseCache = &_dbCache{cache: make(map[string]*DB)} + drivers = map[string]DriverType{ + "mysql": DRMySQL, + "postgres": DRPostgres, + "sqlite3": DRSqlite, + "tidb": DRTiDB, + "oracle": DROracle, + "oci8": DROracle, // github.com/mattn/go-oci8 + "ora": DROracle, // https://github.com/rana/ora + } + dbBasers = map[DriverType]dbBaser{ + DRMySQL: newdbBaseMysql(), + DRSqlite: newdbBaseSqlite(), + DROracle: newdbBaseOracle(), + DRPostgres: newdbBasePostgres(), + DRTiDB: newdbBaseTidb(), + } +) + +// database alias cacher. +type _dbCache struct { + mux sync.RWMutex + cache map[string]*DB +} + +// add database db with original name. +func (ac *_dbCache) add(name string, db *DB) (added bool) { + ac.mux.Lock() + defer ac.mux.Unlock() + if _, ok := ac.cache[name]; !ok { + ac.cache[name] = db + added = true + } + return +} + +// get database alias if cached. +func (ac *_dbCache) get(name string) (db *DB, ok bool) { + ac.mux.RLock() + defer ac.mux.RUnlock() + db, ok = ac.cache[name] + return +} + +func (ac *_dbCache) getORSet(aliasName, driverName string, db *sql.DB, params ...DBOption) (al *DB, err error) { + ac.mux.RLock() + d, ok := ac.cache[aliasName] + ac.mux.RUnlock() + if !ok { + ac.mux.Lock() + defer ac.mux.Unlock() + al, err = newDB(aliasName, driverName, db, params...) + if err != nil { + return + } + ac.cache[aliasName] = d + } + return +} + +// get default alias. +func (ac *_dbCache) getDefault() (db *DB) { + db, _ = ac.get("default") + return +} + +type DB struct { + *sync.RWMutex + DB *sql.DB + stmtDecorators *lru.Cache + stmtDecoratorsLimit int + + name string + driver DriverType + driverName string + dataSource string + maxIdleConns int + maxOpenConns int + connMaxLifeTime time.Duration + connMaxIdleTime time.Duration + stmtCacheSize int + dbBaser dbBaser + tz *time.Location + engine string +} + +var ( + _ dbQuerier = new(DB) + _ txer = new(DB) +) + +func (d *DB) Begin() (*sql.Tx, error) { + return d.DB.Begin() +} + +func (d *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) { + return d.DB.BeginTx(ctx, opts) +} + +// su must call release to release *sql.Stmt after using +func (d *DB) getStmtDecorator(query string) (*stmtDecorator, error) { + d.RLock() + c, ok := d.stmtDecorators.Get(query) + if ok { + c.(*stmtDecorator).acquire() + d.RUnlock() + return c.(*stmtDecorator), nil + } + d.RUnlock() + + d.Lock() + c, ok = d.stmtDecorators.Get(query) + if ok { + c.(*stmtDecorator).acquire() + d.Unlock() + return c.(*stmtDecorator), nil + } + + stmt, err := d.Prepare(query) + if err != nil { + d.Unlock() + return nil, err + } + sd := newStmtDecorator(stmt) + sd.acquire() + d.stmtDecorators.Add(query, sd) + d.Unlock() + + return sd, nil +} + +func (d *DB) Prepare(query string) (*sql.Stmt, error) { + return d.DB.Prepare(query) +} + +func (d *DB) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) { + return d.DB.PrepareContext(ctx, query) +} + +func (d *DB) Exec(query string, args ...interface{}) (sql.Result, error) { + return d.ExecContext(context.Background(), query, args...) +} + +func (d *DB) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { + if d.stmtDecorators == nil { + return d.DB.ExecContext(ctx, query, args...) + } + + sd, err := d.getStmtDecorator(query) + if err != nil { + return nil, err + } + stmt := sd.getStmt() + defer sd.release() + return stmt.ExecContext(ctx, args...) +} + +func (d *DB) Query(query string, args ...interface{}) (*sql.Rows, error) { + return d.QueryContext(context.Background(), query, args...) +} + +func (d *DB) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { + if d.stmtDecorators == nil { + return d.DB.QueryContext(ctx, query, args...) + } + + sd, err := d.getStmtDecorator(query) + if err != nil { + return nil, err + } + stmt := sd.getStmt() + defer sd.release() + return stmt.QueryContext(ctx, args...) +} + +func (d *DB) QueryRow(query string, args ...interface{}) *sql.Row { + return d.QueryRowContext(context.Background(), query, args...) +} + +func (d *DB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { + if d.stmtDecorators == nil { + return d.DB.QueryRowContext(ctx, query, args...) + } + + sd, err := d.getStmtDecorator(query) + if err != nil { + panic(err) + } + stmt := sd.getStmt() + defer sd.release() + return stmt.QueryRowContext(ctx, args...) +} + +type TxDB struct { + tx *sql.Tx +} + +var ( + _ dbQuerier = new(TxDB) + _ txEnder = new(TxDB) +) + +func (t *TxDB) Commit() error { + return t.tx.Commit() +} + +func (t *TxDB) Rollback() error { + return t.tx.Rollback() +} + +func (t *TxDB) RollbackUnlessCommit() error { + err := t.tx.Rollback() + if err != sql.ErrTxDone { + return err + } + return nil +} + +var ( + _ dbQuerier = new(TxDB) + _ txEnder = new(TxDB) +) + +func (t *TxDB) Prepare(query string) (*sql.Stmt, error) { + return t.PrepareContext(context.Background(), query) +} + +func (t *TxDB) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) { + return t.tx.PrepareContext(ctx, query) +} + +func (t *TxDB) Exec(query string, args ...interface{}) (sql.Result, error) { + return t.ExecContext(context.Background(), query, args...) +} + +func (t *TxDB) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { + return t.tx.ExecContext(ctx, query, args...) +} + +func (t *TxDB) Query(query string, args ...interface{}) (*sql.Rows, error) { + return t.QueryContext(context.Background(), query, args...) +} + +func (t *TxDB) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { + return t.tx.QueryContext(ctx, query, args...) +} + +func (t *TxDB) QueryRow(query string, args ...interface{}) *sql.Row { + return t.QueryRowContext(context.Background(), query, args...) +} + +func (t *TxDB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { + return t.tx.QueryRowContext(ctx, query, args...) +} + +func detectTZ(al *DB) { + // orm timezone system match database + // default use Local + al.tz = DefaultTimeLoc + + if al.driverName == "sphinx" { + return + } + + switch al.driver { + case DRMySQL: + row := al.DB.QueryRow("SELECT TIMEDIFF(NOW(), UTC_TIMESTAMP)") + var tz string + if err := row.Scan(&tz); err != nil { + DebugLog.Printf("Detect DB timezone: %s\n", err.Error()) + return + } + if len(tz) >= 8 { + if tz[0] != '-' { + tz = "+" + tz + } + t, err := time.Parse("-07:00:00", tz) + if err == nil { + if t.Location().String() != "" { + al.tz = t.Location() + } + } else { + DebugLog.Printf("Detect DB timezone: %s %s\n", tz, err.Error()) + } + } + + // Get default engine from current database + row = al.DB.QueryRow("SELECT ENGINE, TRANSACTIONS FROM information_schema.engines WHERE SUPPORT = 'DEFAULT'") + var engine string + var tx bool + if err := row.Scan(&engine, &tx); err != nil { + DebugLog.Printf("Detect DB engine: %s\n", err.Error()) + return + } + + if engine != "" { + al.engine = engine + } else { + al.engine = "INNODB" + } + + case DRSqlite, DROracle: + al.tz = time.UTC + + case DRPostgres: + row := al.DB.QueryRow("SELECT current_setting('TIMEZONE')") + var tz string + if err := row.Scan(&tz); err != nil { + DebugLog.Printf("Detect DB timezone: %s\n", err.Error()) + return + } + loc, err := time.LoadLocation(tz) + if err == nil { + al.tz = loc + } else { + DebugLog.Printf("Detect DB timezone: %s %s\n", tz, err.Error()) + } + } +} + +func addDB(aliasName, driverName string, db *sql.DB, params ...DBOption) (*DB, error) { + existErr := fmt.Errorf("DataBase DB name `%s` already registered, cannot reuse", aliasName) + if _, ok := dataBaseCache.get(aliasName); ok { + return nil, existErr + } + + al, err := newDB(aliasName, driverName, db, params...) + if err != nil { + return nil, err + } + + if !dataBaseCache.add(aliasName, al) { + return nil, existErr + } + + return al, nil +} + +func getORSetDB(aliasName, driverName string, db *sql.DB, params ...DBOption) (*DB, error) { + al, err := dataBaseCache.getORSet(aliasName, driverName, db, params...) + if err != nil { + return nil, err + } + + return al, nil +} + +func newDB(aliasName, driverName string, db *sql.DB, params ...DBOption) (*DB, error) { + //al := &alias{} + //al.DB = &DB{ + // RWMutex: new(sync.RWMutex), + // DB: db, + //} + + res := &DB{ + RWMutex: new(sync.RWMutex), + DB: db, + } + for _, p := range params { + p(res) + } + + var stmtCache *lru.Cache + var stmtCacheSize int + + if res.stmtCacheSize > 0 { + _stmtCache, errC := newStmtDecoratorLruWithEvict(res.stmtCacheSize) + if errC != nil { + return nil, errC + } else { + stmtCache = _stmtCache + stmtCacheSize = res.stmtCacheSize + } + } + + res.name = aliasName + res.driverName = driverName + res.stmtDecorators = stmtCache + res.stmtDecoratorsLimit = stmtCacheSize + + if dr, ok := drivers[driverName]; ok { + res.dbBaser = dbBasers[dr] + res.driver = dr + } else { + return nil, fmt.Errorf("driver name `%s` have not registered", driverName) + } + + err := db.Ping() + if err != nil { + return nil, fmt.Errorf("Register db Ping `%s`, %s", aliasName, err.Error()) + } + + detectTZ(res) + + return res, nil +} + +// SetMaxIdleConns Change the max idle conns for *sql.DB, use specify database alias name +// Deprecated you should not use this, we will remove it in the future +func SetMaxIdleConns(aliasName string, maxIdleConns int) { + d := getDB(aliasName) + d.SetMaxIdleConns(maxIdleConns) +} + +// SetMaxOpenConns Change the max open conns for *sql.DB, use specify database alias name +// Deprecated you should not use this, we will remove it in the future +func SetMaxOpenConns(aliasName string, maxOpenConns int) { + d := getDB(aliasName) + d.SetMaxOpenConns(maxOpenConns) +} + +// SetMaxIdleConns Change the max idle conns for *sql.DB, use specify database alias name +func (d *DB) SetMaxIdleConns(maxIdleConns int) { + d.maxIdleConns = maxIdleConns + d.DB.SetMaxIdleConns(maxIdleConns) +} + +// SetMaxOpenConns Change the max open conns for *sql.DB, use specify database alias name +func (d *DB) SetMaxOpenConns(maxOpenConns int) { + d.maxOpenConns = maxOpenConns + d.DB.SetMaxOpenConns(maxOpenConns) +} + +func (d *DB) SetConnMaxLifetime(lifeTime time.Duration) { + d.connMaxLifeTime = lifeTime + d.DB.SetConnMaxLifetime(lifeTime) +} + +func (d *DB) SetConnMaxIdleTime(idleTime time.Duration) { + d.connMaxIdleTime = idleTime + d.DB.SetConnMaxIdleTime(idleTime) +} + +// AddDB add a aliasName for the drivename +func AddDB(aliasName, driverName string, db *sql.DB, params ...DBOption) error { + _, err := addDB(aliasName, driverName, db, params...) + return err +} + +// RegisterDataBase Setting the database connect params. Use the database driver self dataSource args. +func RegisterDataBase(aliasName, driverName, dataSource string, params ...DBOption) error { + var ( + err error + db *sql.DB + res *DB + ) + db, err = sql.Open(driverName, dataSource) + if err != nil { + err = fmt.Errorf("register db `%s`, %s", aliasName, err.Error()) + goto end + } + + res, err = addDB(aliasName, driverName, db, params...) + if err != nil { + goto end + } + + res.dataSource = dataSource + +end: + if err != nil { + if db != nil { + _ = db.Close() + } + DebugLog.Println(err.Error()) + } + + return err +} + +func RegisterDB(aliasName, driverName, dataSource string, params ...DBOption) error { + var ( + err error + db *sql.DB + res *DB + ) + db, err = sql.Open(driverName, dataSource) + if err != nil { + err = fmt.Errorf("Register db `%s`, %s", aliasName, err.Error()) + goto end + } + + res, err = getORSetDB(aliasName, driverName, db, params...) + if err != nil { + goto end + } + + if res.dataSource != dataSource { + res.dataSource = dataSource + } + +end: + if err != nil { + if db != nil { + _ = db.Close() + } + DebugLog.Println(err.Error()) + } + + return err +} + +// RegisterDriver Register a database driver use specify driver name, this can be definition the driver is which database type. +func RegisterDriver(driverName string, typ DriverType) error { + if t, ok := drivers[driverName]; !ok { + drivers[driverName] = typ + } else { + if t != typ { + return fmt.Errorf("driverName `%s` db driver already registered and is other type", driverName) + } + } + return nil +} + +// SetDataBaseTZ Change the database default used timezone +func SetDataBaseTZ(dbName string, tz *time.Location) error { + if db, ok := dataBaseCache.get(dbName); ok { + db.tz = tz + } else { + return fmt.Errorf("DataBase DB name `%s` not registered", dbName) + } + return nil +} + +// GetDB Get *DB from registered database by db aDB name. +// Use "default" as DB name if you not Set. +func GetDB(dbNames ...string) (*DB, error) { + var name string + if len(dbNames) > 0 { + name = dbNames[0] + } else { + name = "default" + } + al, ok := dataBaseCache.get(name) + if ok { + return al, nil + } + return nil, fmt.Errorf("DataBase of DB name `%s` not found", name) +} + +type stmtDecorator struct { + wg sync.WaitGroup + stmt *sql.Stmt +} + +func (s *stmtDecorator) getStmt() *sql.Stmt { + return s.stmt +} + +// acquire will add one +// since this method will be used inside read lock scope, +// so we can not do more things here +// we should think about refactor this +func (s *stmtDecorator) acquire() { + s.wg.Add(1) +} + +func (s *stmtDecorator) release() { + s.wg.Done() +} + +// garbage recycle for stmt +func (s *stmtDecorator) destroy() { + go func() { + s.wg.Wait() + _ = s.stmt.Close() + }() +} + +func newStmtDecorator(sqlStmt *sql.Stmt) *stmtDecorator { + return &stmtDecorator{ + stmt: sqlStmt, + } +} + +func newStmtDecoratorLruWithEvict(cacheSize int) (*lru.Cache, error) { + cache, err := lru.NewWithEvict(cacheSize, func(key interface{}, value interface{}) { + value.(*stmtDecorator).destroy() + }) + if err != nil { + return nil, err + } + return cache, nil +} + +type DBOption func(d *DB) + +// MaxIdleConnections return a hint about MaxIdleConnections +func MaxIdleConnections(maxIdleConn int) DBOption { + return func(d *DB) { + d.SetMaxIdleConns(maxIdleConn) + } +} + +// MaxOpenConnections return a hint about MaxOpenConnections +func MaxOpenConnections(maxOpenConn int) DBOption { + return func(d *DB) { + d.SetMaxOpenConns(maxOpenConn) + } +} + +// ConnMaxLifetime return a hint about ConnMaxLifetime +func ConnMaxLifetime(v time.Duration) DBOption { + return func(d *DB) { + d.SetConnMaxLifetime(v) + } +} + +// ConnMaxIdletime return a hint about ConnMaxIdletime +func ConnMaxIdletime(v time.Duration) DBOption { + return func(d *DB) { + d.SetConnMaxIdleTime(v) + } +} + +// MaxStmtCacheSize return a hint about MaxStmtCacheSize +func MaxStmtCacheSize(v int) DBOption { + return func(d *DB) { + d.stmtCacheSize = v + } +} diff --git a/client/orm/db_alias_test.go b/client/orm/db_alias_test.go new file mode 100644 index 0000000000..a11000142f --- /dev/null +++ b/client/orm/db_alias_test.go @@ -0,0 +1,88 @@ +// Copyright 2020 beego-dev +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestRegisterDataBase(t *testing.T) { + err := RegisterDataBase("test-params", DBARGS.Driver, DBARGS.Source, + MaxIdleConnections(20), + MaxOpenConnections(300), + ConnMaxLifetime(time.Minute), + ConnMaxIdletime(time.Minute)) + assert.Nil(t, err) + + al := getDB("test-params") + assert.NotNil(t, al) + assert.Equal(t, al.maxIdleConns, 20) + assert.Equal(t, al.maxOpenConns, 300) + assert.Equal(t, al.connMaxLifeTime, time.Minute) + assert.Equal(t, al.connMaxIdleTime, time.Minute) +} + +func TestRegisterDataBaseMaxStmtCacheSizeNegative1(t *testing.T) { + dbName := "TestRegisterDataBase_MaxStmtCacheSizeNegative1" + err := RegisterDataBase(dbName, DBARGS.Driver, DBARGS.Source, MaxStmtCacheSize(-1)) + assert.Nil(t, err) + + al := getDB(dbName) + assert.NotNil(t, al) + assert.Equal(t, al.stmtDecoratorsLimit, 0) +} + +func TestRegisterDataBaseMaxStmtCacheSize0(t *testing.T) { + dbName := "TestRegisterDataBase_MaxStmtCacheSize0" + err := RegisterDataBase(dbName, DBARGS.Driver, DBARGS.Source, MaxStmtCacheSize(0)) + assert.Nil(t, err) + + al := getDB(dbName) + assert.NotNil(t, al) + assert.Equal(t, al.stmtDecoratorsLimit, 0) +} + +func TestRegisterDataBaseMaxStmtCacheSize1(t *testing.T) { + dbName := "TestRegisterDataBase_MaxStmtCacheSize1" + err := RegisterDataBase(dbName, DBARGS.Driver, DBARGS.Source, MaxStmtCacheSize(1)) + assert.Nil(t, err) + + al := getDB(dbName) + assert.NotNil(t, al) + assert.Equal(t, al.stmtDecoratorsLimit, 1) +} + +func TestRegisterDataBaseMaxStmtCacheSize841(t *testing.T) { + dbName := "TestRegisterDataBase_MaxStmtCacheSize841" + err := RegisterDataBase(dbName, DBARGS.Driver, DBARGS.Source, MaxStmtCacheSize(841)) + assert.Nil(t, err) + + al := getDB(dbName) + assert.NotNil(t, al) + assert.Equal(t, al.stmtDecoratorsLimit, 841) +} + +func TestDBCache(t *testing.T) { + dataBaseCache.add("test1", &DB{}) + dataBaseCache.add("default", &DB{}) + al := dataBaseCache.getDefault() + assert.NotNil(t, al) + al, ok := dataBaseCache.get("test1") + assert.NotNil(t, al) + assert.True(t, ok) +} diff --git a/orm/db_mysql.go b/client/orm/db_mysql.go similarity index 59% rename from orm/db_mysql.go rename to client/orm/db_mysql.go index 6e99058ec9..556b550d29 100644 --- a/orm/db_mysql.go +++ b/client/orm/db_mysql.go @@ -15,17 +15,21 @@ package orm import ( + "context" "fmt" "reflect" "strings" + + "github.com/beego/beego/v2/client/orm/internal/models" ) // mysql operators. var mysqlOperators = map[string]string{ - "exact": "= ?", - "iexact": "LIKE ?", - "contains": "LIKE BINARY ?", - "icontains": "LIKE ?", + "exact": "= ?", + "iexact": "LIKE ?", + "strictexact": "= BINARY ?", + "contains": "LIKE BINARY ?", + "icontains": "LIKE ?", // "regex": "REGEXP BINARY ?", // "iregex": "REGEXP ?", "gt": "> ?", @@ -42,24 +46,25 @@ var mysqlOperators = map[string]string{ // mysql column field types. var mysqlTypes = map[string]string{ - "auto": "AUTO_INCREMENT NOT NULL PRIMARY KEY", - "pk": "NOT NULL PRIMARY KEY", - "bool": "bool", - "string": "varchar(%d)", - "string-char": "char(%d)", - "string-text": "longtext", - "time.Time-date": "date", - "time.Time": "datetime", - "int8": "tinyint", - "int16": "smallint", - "int32": "integer", - "int64": "bigint", - "uint8": "tinyint unsigned", - "uint16": "smallint unsigned", - "uint32": "integer unsigned", - "uint64": "bigint unsigned", - "float64": "double precision", - "float64-decimal": "numeric(%d, %d)", + "auto": "AUTO_INCREMENT NOT NULL PRIMARY KEY", + "pk": "NOT NULL PRIMARY KEY", + "bool": "bool", + "string": "varchar(%d)", + "string-char": "char(%d)", + "string-text": "longtext", + "time.Time-date": "date", + "time.Time": "datetime", + "int8": "tinyint", + "int16": "smallint", + "int32": "integer", + "int64": "bigint", + "uint8": "tinyint unsigned", + "uint16": "smallint unsigned", + "uint32": "integer unsigned", + "uint64": "bigint unsigned", + "float64": "double precision", + "float64-decimal": "numeric(%d, %d)", + "time.Time-precision": "datetime(%d)", } // mysql dbBaser implementation. @@ -69,30 +74,30 @@ type dbBaseMysql struct { var _ dbBaser = new(dbBaseMysql) -// get mysql operator. +// OperatorSQL Get mysql operator. func (d *dbBaseMysql) OperatorSQL(operator string) string { return mysqlOperators[operator] } -// get mysql table field types. +// DbTypes Get mysql table field types. func (d *dbBaseMysql) DbTypes() map[string]string { return mysqlTypes } -// show table sql for mysql. +// ShowTablesQuery show table sql for mysql. func (d *dbBaseMysql) ShowTablesQuery() string { return "SELECT table_name FROM information_schema.tables WHERE table_type = 'BASE TABLE' AND table_schema = DATABASE()" } -// show columns sql of table for mysql. +// ShowColumnsQuery show Columns sql of table for mysql. func (d *dbBaseMysql) ShowColumnsQuery(table string) string { - return fmt.Sprintf("SELECT COLUMN_NAME, COLUMN_TYPE, IS_NULLABLE FROM information_schema.columns "+ + return fmt.Sprintf("SELECT COLUMN_NAME, COLUMN_TYPE, IS_NULLABLE FROM information_schema.Columns "+ "WHERE table_schema = DATABASE() AND table_name = '%s'", table) } -// execute sql to check index exist. -func (d *dbBaseMysql) IndexExists(db dbQuerier, table string, name string) bool { - row := db.QueryRow("SELECT count(*) FROM information_schema.statistics "+ +// IndexExists execute sql to check index exist. +func (d *dbBaseMysql) IndexExists(ctx context.Context, db dbQuerier, table string, name string) bool { + row := db.QueryRowContext(ctx, "SELECT count(*) FROM information_schema.statistics "+ "WHERE table_schema = DATABASE() AND table_name = ? AND index_name = ?", table, name) var cnt int row.Scan(&cnt) @@ -103,13 +108,13 @@ func (d *dbBaseMysql) IndexExists(db dbQuerier, table string, name string) bool // If your primary key or unique column conflict will update // If no will insert // Add "`" for mysql sql building -func (d *dbBaseMysql) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Value, a *alias, args ...string) (int64, error) { +func (d *dbBaseMysql) InsertOrUpdate(ctx context.Context, q dbQuerier, mi *models.ModelInfo, ind reflect.Value, a *DB, args ...string) (int64, error) { var iouStr string argsMap := map[string]string{} iouStr = "ON DUPLICATE KEY UPDATE" - //Get on the key-value pairs + // Get on the key-value pairs for _, v := range args { kv := strings.Split(v, "=") if len(kv) == 2 { @@ -117,11 +122,9 @@ func (d *dbBaseMysql) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Val } } - isMulti := false - names := make([]string, 0, len(mi.fields.dbcols)-1) + names := make([]string, 0, len(mi.Fields.DBcols)-1) Q := d.ins.TableQuote() - values, _, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, &names, a.TZ) - + values, _, err := d.collectValues(mi, ind, mi.Fields.DBcols, true, true, &names, a.tz) if err != nil { return 0, err } @@ -148,28 +151,26 @@ func (d *dbBaseMysql) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Val qupdates := strings.Join(updates, ", ") columns := strings.Join(names, sep) - multi := len(values) / len(names) - - if isMulti { - qmarks = strings.Repeat(qmarks+"), (", multi-1) + qmarks - } - //conflitValue maybe is a int,can`t use fmt.Sprintf - query := fmt.Sprintf("INSERT INTO %s%s%s (%s%s%s) VALUES (%s) %s "+qupdates, Q, mi.table, Q, Q, columns, Q, qmarks, iouStr) + // conflitValue maybe is an int,can`t use fmt.Sprintf + query := fmt.Sprintf("INSERT INTO %s%s%s (%s%s%s) VALUES (%s) %s "+qupdates, Q, mi.Table, Q, Q, columns, Q, qmarks, iouStr) d.ins.ReplaceMarks(&query) - if isMulti || !d.ins.HasReturningID(mi, &query) { - res, err := q.Exec(query, values...) + if !d.ins.HasReturningID(mi, &query) { + res, err := q.ExecContext(ctx, query, values...) if err == nil { - if isMulti { - return res.RowsAffected() + lastInsertId, err := res.LastInsertId() + if err != nil { + DebugLog.Println(ErrLastInsertIdUnavailable, ':', err) + return lastInsertId, ErrLastInsertIdUnavailable + } else { + return lastInsertId, nil } - return res.LastInsertId() } return 0, err } - row := q.QueryRow(query, values...) + row := q.QueryRowContext(ctx, query, values...) var id int64 err = row.Scan(&id) return id, err diff --git a/orm/db_oracle.go b/client/orm/db_oracle.go similarity index 52% rename from orm/db_oracle.go rename to client/orm/db_oracle.go index 5d121f8342..8776f23ad7 100644 --- a/orm/db_oracle.go +++ b/client/orm/db_oracle.go @@ -15,8 +15,13 @@ package orm import ( + "context" "fmt" "strings" + + "github.com/beego/beego/v2/client/orm/internal/models" + + "github.com/beego/beego/v2/client/orm/hints" ) // oracle operators. @@ -31,23 +36,24 @@ var oracleOperators = map[string]string{ // oracle column field types. var oracleTypes = map[string]string{ - "pk": "NOT NULL PRIMARY KEY", - "bool": "bool", - "string": "VARCHAR2(%d)", - "string-char": "CHAR(%d)", - "string-text": "VARCHAR2(%d)", - "time.Time-date": "DATE", - "time.Time": "TIMESTAMP", - "int8": "INTEGER", - "int16": "INTEGER", - "int32": "INTEGER", - "int64": "INTEGER", - "uint8": "INTEGER", - "uint16": "INTEGER", - "uint32": "INTEGER", - "uint64": "INTEGER", - "float64": "NUMBER", - "float64-decimal": "NUMBER(%d, %d)", + "pk": "NOT NULL PRIMARY KEY", + "bool": "bool", + "string": "VARCHAR2(%d)", + "string-char": "CHAR(%d)", + "string-text": "VARCHAR2(%d)", + "time.Time-date": "DATE", + "time.Time": "TIMESTAMP", + "int8": "INTEGER", + "int16": "INTEGER", + "int32": "INTEGER", + "int64": "INTEGER", + "uint8": "INTEGER", + "uint16": "INTEGER", + "uint32": "INTEGER", + "uint64": "INTEGER", + "float64": "NUMBER", + "float64-decimal": "NUMBER(%d, %d)", + "time.Time-precision": "TIMESTAMP(%d)", } // oracle dbBaser @@ -64,17 +70,17 @@ func newdbBaseOracle() dbBaser { return b } -// OperatorSQL get oracle operator. +// OperatorSQL Get oracle operator. func (d *dbBaseOracle) OperatorSQL(operator string) string { return oracleOperators[operator] } -// DbTypes get oracle table field types. +// DbTypes Get oracle table field types. func (d *dbBaseOracle) DbTypes() map[string]string { return oracleTypes } -//ShowTablesQuery show all the tables in database +// ShowTablesQuery show All the tables in database func (d *dbBaseOracle) ShowTablesQuery() string { return "SELECT TABLE_NAME FROM USER_TABLES" } @@ -86,8 +92,8 @@ func (d *dbBaseOracle) ShowColumnsQuery(table string) string { } // check index is exist -func (d *dbBaseOracle) IndexExists(db dbQuerier, table string, name string) bool { - row := db.QueryRow("SELECT COUNT(*) FROM USER_IND_COLUMNS, USER_INDEXES "+ +func (d *dbBaseOracle) IndexExists(ctx context.Context, db dbQuerier, table string, name string) bool { + row := db.QueryRowContext(ctx, "SELECT COUNT(*) FROM USER_IND_COLUMNS, USER_INDEXES "+ "WHERE USER_IND_COLUMNS.INDEX_NAME = USER_INDEXES.INDEX_NAME "+ "AND USER_IND_COLUMNS.TABLE_NAME = ? AND USER_IND_COLUMNS.INDEX_NAME = ?", strings.ToUpper(table), strings.ToUpper(name)) @@ -96,9 +102,32 @@ func (d *dbBaseOracle) IndexExists(db dbQuerier, table string, name string) bool return cnt > 0 } -// execute insert sql with given struct and given values. +func (d *dbBaseOracle) GenerateSpecifyIndex(tableName string, useIndex int, indexes []string) string { + var s []string + Q := d.TableQuote() + for _, index := range indexes { + tmp := fmt.Sprintf(`%s%s%s`, Q, index, Q) + s = append(s, tmp) + } + + var hint string + + switch useIndex { + case hints.KeyUseIndex, hints.KeyForceIndex: + hint = `INDEX` + case hints.KeyIgnoreIndex: + hint = `NO_INDEX` + default: + DebugLog.Println("[WARN] Not a valid specifying action, so that action is ignored") + return `` + } + + return fmt.Sprintf(` /*+ %s(%s %s)*/ `, hint, tableName, strings.Join(s, `,`)) +} + +// InsertValue execute insert sql with given struct and given values. // insert the given values, not the field values in struct. -func (d *dbBaseOracle) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []string, values []interface{}) (int64, error) { +func (d *dbBaseOracle) InsertValue(ctx context.Context, q dbQuerier, mi *models.ModelInfo, isMulti bool, names []string, values []interface{}) (int64, error) { Q := d.ins.TableQuote() marks := make([]string, len(names)) @@ -116,21 +145,28 @@ func (d *dbBaseOracle) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, nam qmarks = strings.Repeat(qmarks+"), (", multi-1) + qmarks } - query := fmt.Sprintf("INSERT INTO %s%s%s (%s%s%s) VALUES (%s)", Q, mi.table, Q, Q, columns, Q, qmarks) + query := fmt.Sprintf("INSERT INTO %s%s%s (%s%s%s) VALUES (%s)", Q, mi.Table, Q, Q, columns, Q, qmarks) d.ins.ReplaceMarks(&query) if isMulti || !d.ins.HasReturningID(mi, &query) { - res, err := q.Exec(query, values...) + res, err := q.ExecContext(ctx, query, values...) if err == nil { if isMulti { return res.RowsAffected() } - return res.LastInsertId() + + lastInsertId, err := res.LastInsertId() + if err != nil { + DebugLog.Println(ErrLastInsertIdUnavailable, ':', err) + return lastInsertId, ErrLastInsertIdUnavailable + } else { + return lastInsertId, nil + } } return 0, err } - row := q.QueryRow(query, values...) + row := q.QueryRowContext(ctx, query, values...) var id int64 err := row.Scan(&id) return id, err diff --git a/orm/db_postgres.go b/client/orm/db_postgres.go similarity index 59% rename from orm/db_postgres.go rename to client/orm/db_postgres.go index c488fb3889..bbe2eedd35 100644 --- a/orm/db_postgres.go +++ b/client/orm/db_postgres.go @@ -15,8 +15,11 @@ package orm import ( + "context" "fmt" "strconv" + + "github.com/beego/beego/v2/client/orm/internal/models" ) // postgresql operators. @@ -39,26 +42,27 @@ var postgresOperators = map[string]string{ // postgresql column field types. var postgresTypes = map[string]string{ - "auto": "serial NOT NULL PRIMARY KEY", - "pk": "NOT NULL PRIMARY KEY", - "bool": "bool", - "string": "varchar(%d)", - "string-char": "char(%d)", - "string-text": "text", - "time.Time-date": "date", - "time.Time": "timestamp with time zone", - "int8": `smallint CHECK("%COL%" >= -127 AND "%COL%" <= 128)`, - "int16": "smallint", - "int32": "integer", - "int64": "bigint", - "uint8": `smallint CHECK("%COL%" >= 0 AND "%COL%" <= 255)`, - "uint16": `integer CHECK("%COL%" >= 0)`, - "uint32": `bigint CHECK("%COL%" >= 0)`, - "uint64": `bigint CHECK("%COL%" >= 0)`, - "float64": "double precision", - "float64-decimal": "numeric(%d, %d)", - "json": "json", - "jsonb": "jsonb", + "auto": "bigserial NOT NULL PRIMARY KEY", + "pk": "NOT NULL PRIMARY KEY", + "bool": "bool", + "string": "varchar(%d)", + "string-char": "char(%d)", + "string-text": "text", + "time.Time-date": "date", + "time.Time": "timestamp with time zone", + "int8": `smallint CHECK("%COL%" >= -127 AND "%COL%" <= 128)`, + "int16": "smallint", + "int32": "integer", + "int64": "bigint", + "uint8": `smallint CHECK("%COL%" >= 0 AND "%COL%" <= 255)`, + "uint16": `integer CHECK("%COL%" >= 0)`, + "uint32": `bigint CHECK("%COL%" >= 0)`, + "uint64": `bigint CHECK("%COL%" >= 0)`, + "float64": "double precision", + "float64-decimal": "numeric(%d, %d)", + "json": "json", + "jsonb": "jsonb", + "time.Time-precision": "timestamp(%d) with time zone", } // postgresql dbBaser. @@ -68,13 +72,13 @@ type dbBasePostgres struct { var _ dbBaser = new(dbBasePostgres) -// get postgresql operator. +// Get postgresql operator. func (d *dbBasePostgres) OperatorSQL(operator string) string { return postgresOperators[operator] } // generate functioned sql string, such as contains(text). -func (d *dbBasePostgres) GenerateOperatorLeftCol(fi *fieldInfo, operator string, leftCol *string) { +func (d *dbBasePostgres) GenerateOperatorLeftCol(fi *models.FieldInfo, operator string, leftCol *string) { switch operator { case "contains", "startswith", "endswith": *leftCol = fmt.Sprintf("%s::text", *leftCol) @@ -83,7 +87,7 @@ func (d *dbBasePostgres) GenerateOperatorLeftCol(fi *fieldInfo, operator string, } } -// postgresql unsupports updating joined record. +// SupportUpdateJoin postgresql unsupports updating joined record. func (d *dbBasePostgres) SupportUpdateJoin() bool { return false } @@ -92,12 +96,12 @@ func (d *dbBasePostgres) MaxLimit() uint64 { return 0 } -// postgresql quote is ". +// TableQuote postgresql quote is ". func (d *dbBasePostgres) TableQuote() string { return `"` } -// postgresql value placeholder is $n. +// ReplaceMarks postgresql value placeholder is $n. // replace default ? to $n. func (d *dbBasePostgres) ReplaceMarks(query *string) { q := *query @@ -125,21 +129,20 @@ func (d *dbBasePostgres) ReplaceMarks(query *string) { *query = string(data) } -// make returning sql support for postgresql. -func (d *dbBasePostgres) HasReturningID(mi *modelInfo, query *string) bool { - fi := mi.fields.pk - if fi.fieldType&IsPositiveIntegerField == 0 && fi.fieldType&IsIntegerField == 0 { +// HasReturningID make returning sql support for postgresql. +func (d *dbBasePostgres) HasReturningID(mi *models.ModelInfo, query *string) bool { + fi := mi.Fields.Pk + if fi.FieldType&IsPositiveIntegerField == 0 && fi.FieldType&IsIntegerField == 0 { return false } - if query != nil { - *query = fmt.Sprintf(`%s RETURNING "%s"`, *query, fi.column) + *query = fmt.Sprintf(`%s RETURNING "%s"`, *query, fi.Column) } return true } // sync auto key -func (d *dbBasePostgres) setval(db dbQuerier, mi *modelInfo, autoFields []string) error { +func (d *dbBasePostgres) setval(ctx context.Context, db dbQuerier, mi *models.ModelInfo, autoFields []string) error { if len(autoFields) == 0 { return nil } @@ -147,10 +150,10 @@ func (d *dbBasePostgres) setval(db dbQuerier, mi *modelInfo, autoFields []string Q := d.ins.TableQuote() for _, name := range autoFields { query := fmt.Sprintf("SELECT setval(pg_get_serial_sequence('%s', '%s'), (SELECT MAX(%s%s%s) FROM %s%s%s));", - mi.table, name, + mi.Table, name, Q, name, Q, - Q, mi.table, Q) - if _, err := db.Exec(query); err != nil { + Q, mi.Table, Q) + if _, err := db.ExecContext(ctx, query); err != nil { return err } } @@ -162,25 +165,31 @@ func (d *dbBasePostgres) ShowTablesQuery() string { return "SELECT table_name FROM information_schema.tables WHERE table_type = 'BASE TABLE' AND table_schema NOT IN ('pg_catalog', 'information_schema')" } -// show table columns sql for postgresql. +// show table Columns sql for postgresql. func (d *dbBasePostgres) ShowColumnsQuery(table string) string { - return fmt.Sprintf("SELECT column_name, data_type, is_nullable FROM information_schema.columns where table_schema NOT IN ('pg_catalog', 'information_schema') and table_name = '%s'", table) + return fmt.Sprintf("SELECT column_name, data_type, is_nullable FROM information_schema.Columns where table_schema NOT IN ('pg_catalog', 'information_schema') and table_name = '%s'", table) } -// get column types of postgresql. +// Get column types of postgresql. func (d *dbBasePostgres) DbTypes() map[string]string { return postgresTypes } // check index exist in postgresql. -func (d *dbBasePostgres) IndexExists(db dbQuerier, table string, name string) bool { +func (d *dbBasePostgres) IndexExists(ctx context.Context, db dbQuerier, table string, name string) bool { query := fmt.Sprintf("SELECT COUNT(*) FROM pg_indexes WHERE tablename = '%s' AND indexname = '%s'", table, name) - row := db.QueryRow(query) + row := db.QueryRowContext(ctx, query) var cnt int row.Scan(&cnt) return cnt > 0 } +// GenerateSpecifyIndex return a specifying index clause +func (d *dbBasePostgres) GenerateSpecifyIndex(tableName string, useIndex int, indexes []string) string { + DebugLog.Println("[WARN] Not support any specifying index action, so that action is ignored") + return `` +} + // create new postgresql dbBaser. func newdbBasePostgres() dbBaser { b := new(dbBasePostgres) diff --git a/client/orm/db_postgres_test.go b/client/orm/db_postgres_test.go new file mode 100644 index 0000000000..c669cae3b9 --- /dev/null +++ b/client/orm/db_postgres_test.go @@ -0,0 +1,38 @@ +// Copyright 2023 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "reflect" + "testing" + + "github.com/stretchr/testify/assert" + + imodels "github.com/beego/beego/v2/client/orm/internal/models" +) + +func TestDbBasePostgres_HasReturningID(t *testing.T) { + base := newdbBasePostgres() + val := reflect.ValueOf(&StringID{}) + mi := imodels.NewModelInfo(val) + str := "" + ok := base.HasReturningID(mi, &str) + assert.False(t, ok) + assert.Equal(t, "", str) +} + +type StringID struct { + ID string `orm:"pk"` +} diff --git a/orm/db_sqlite.go b/client/orm/db_sqlite.go similarity index 52% rename from orm/db_sqlite.go rename to client/orm/db_sqlite.go index 0f54d81a45..4228259221 100644 --- a/orm/db_sqlite.go +++ b/client/orm/db_sqlite.go @@ -15,8 +15,16 @@ package orm import ( + "context" "database/sql" "fmt" + "reflect" + "strings" + "time" + + "github.com/beego/beego/v2/client/orm/internal/models" + + "github.com/beego/beego/v2/client/orm/hints" ) // sqlite operators. @@ -39,24 +47,25 @@ var sqliteOperators = map[string]string{ // sqlite column types. var sqliteTypes = map[string]string{ - "auto": "integer NOT NULL PRIMARY KEY AUTOINCREMENT", - "pk": "NOT NULL PRIMARY KEY", - "bool": "bool", - "string": "varchar(%d)", - "string-char": "character(%d)", - "string-text": "text", - "time.Time-date": "date", - "time.Time": "datetime", - "int8": "tinyint", - "int16": "smallint", - "int32": "integer", - "int64": "bigint", - "uint8": "tinyint unsigned", - "uint16": "smallint unsigned", - "uint32": "integer unsigned", - "uint64": "bigint unsigned", - "float64": "real", - "float64-decimal": "decimal", + "auto": "integer NOT NULL PRIMARY KEY AUTOINCREMENT", + "pk": "NOT NULL PRIMARY KEY", + "bool": "bool", + "string": "varchar(%d)", + "string-char": "character(%d)", + "string-text": "text", + "time.Time-date": "date", + "time.Time": "datetime", + "time.Time-precision": "datetime(%d)", + "int8": "tinyint", + "int16": "smallint", + "int32": "integer", + "int64": "bigint", + "uint8": "tinyint unsigned", + "uint16": "smallint unsigned", + "uint32": "integer unsigned", + "uint64": "bigint unsigned", + "float64": "real", + "float64-decimal": "decimal", } // sqlite dbBaser. @@ -66,15 +75,23 @@ type dbBaseSqlite struct { var _ dbBaser = new(dbBaseSqlite) -// get sqlite operator. +// override base db read for update behavior as SQlite does not support syntax +func (d *dbBaseSqlite) Read(ctx context.Context, q dbQuerier, mi *models.ModelInfo, ind reflect.Value, tz *time.Location, cols []string, isForUpdate bool) error { + if isForUpdate { + DebugLog.Println("[WARN] SQLite does not support SELECT FOR UPDATE query, isForUpdate param is ignored and always as false to do the work") + } + return d.dbBase.Read(ctx, q, mi, ind, tz, cols, false) +} + +// Get sqlite operator. func (d *dbBaseSqlite) OperatorSQL(operator string) string { return sqliteOperators[operator] } // generate functioned sql for sqlite. // only support DATE(text). -func (d *dbBaseSqlite) GenerateOperatorLeftCol(fi *fieldInfo, operator string, leftCol *string) { - if fi.fieldType == TypeDateField { +func (d *dbBaseSqlite) GenerateOperatorLeftCol(fi *models.FieldInfo, operator string, leftCol *string) { + if fi.FieldType == TypeDateField { *leftCol = fmt.Sprintf("DATE(%s)", *leftCol) } } @@ -89,20 +106,20 @@ func (d *dbBaseSqlite) MaxLimit() uint64 { return 9223372036854775807 } -// get column types in sqlite. +// Get column types in sqlite. func (d *dbBaseSqlite) DbTypes() map[string]string { return sqliteTypes } -// get show tables sql in sqlite. +// Get show tables sql in sqlite. func (d *dbBaseSqlite) ShowTablesQuery() string { return "SELECT name FROM sqlite_master WHERE type = 'table'" } -// get columns in sqlite. -func (d *dbBaseSqlite) GetColumns(db dbQuerier, table string) (map[string][3]string, error) { +// Get Columns in sqlite. +func (d *dbBaseSqlite) GetColumns(ctx context.Context, db dbQuerier, table string) (map[string][3]string, error) { query := d.ins.ShowColumnsQuery(table) - rows, err := db.Query(query) + rows, err := db.QueryContext(ctx, query) if err != nil { return nil, err } @@ -117,18 +134,18 @@ func (d *dbBaseSqlite) GetColumns(db dbQuerier, table string) (map[string][3]str columns[name.String] = [3]string{name.String, typ.String, null.String} } - return columns, nil + return columns, rows.Err() } -// get show columns sql in sqlite. +// Get show Columns sql in sqlite. func (d *dbBaseSqlite) ShowColumnsQuery(table string) string { return fmt.Sprintf("pragma table_info('%s')", table) } // check index exist in sqlite. -func (d *dbBaseSqlite) IndexExists(db dbQuerier, table string, name string) bool { +func (d *dbBaseSqlite) IndexExists(ctx context.Context, db dbQuerier, table string, name string) bool { query := fmt.Sprintf("PRAGMA index_list('%s')", table) - rows, err := db.Query(query) + rows, err := db.QueryContext(ctx, query) if err != nil { panic(err) } @@ -143,6 +160,24 @@ func (d *dbBaseSqlite) IndexExists(db dbQuerier, table string, name string) bool return false } +// GenerateSpecifyIndex return a specifying index clause +func (d *dbBaseSqlite) GenerateSpecifyIndex(tableName string, useIndex int, indexes []string) string { + var s []string + Q := d.TableQuote() + for _, index := range indexes { + tmp := fmt.Sprintf(`%s%s%s`, Q, index, Q) + s = append(s, tmp) + } + + switch useIndex { + case hints.KeyUseIndex, hints.KeyForceIndex: + return fmt.Sprintf(` INDEXED BY %s `, strings.Join(s, `,`)) + default: + DebugLog.Println("[WARN] Not a valid specifying action, so that action is ignored") + return `` + } +} + // create new sqlite dbBaser. func newdbBaseSqlite() dbBaser { b := new(dbBaseSqlite) diff --git a/orm/db_tables.go b/client/orm/db_tables.go similarity index 66% rename from orm/db_tables.go rename to client/orm/db_tables.go index 4b21a6fc72..c22f0a1108 100644 --- a/orm/db_tables.go +++ b/client/orm/db_tables.go @@ -18,6 +18,11 @@ import ( "fmt" "strings" "time" + + "github.com/beego/beego/v2/client/orm/internal/models" + + "github.com/beego/beego/v2/client/orm/clauses" + "github.com/beego/beego/v2/client/orm/clauses/order_clause" ) // table info struct. @@ -28,8 +33,8 @@ type dbTable struct { names []string sel bool inner bool - mi *modelInfo - fi *fieldInfo + mi *models.ModelInfo + fi *models.FieldInfo jtl *dbTable } @@ -37,14 +42,14 @@ type dbTable struct { type dbTables struct { tablesM map[string]*dbTable tables []*dbTable - mi *modelInfo + mi *models.ModelInfo base dbBaser skipEnd bool } // set table info to collection. // if not exist, create new. -func (t *dbTables) set(names []string, mi *modelInfo, fi *fieldInfo, inner bool) *dbTable { +func (t *dbTables) set(names []string, mi *models.ModelInfo, fi *models.FieldInfo, inner bool) *dbTable { name := strings.Join(names, ExprSep) if j, ok := t.tablesM[name]; ok { j.name = name @@ -61,7 +66,7 @@ func (t *dbTables) set(names []string, mi *modelInfo, fi *fieldInfo, inner bool) } // add table info to collection. -func (t *dbTables) add(names []string, mi *modelInfo, fi *fieldInfo, inner bool) (*dbTable, bool) { +func (t *dbTables) add(names []string, mi *models.ModelInfo, fi *models.FieldInfo, inner bool) (*dbTable, bool) { name := strings.Join(names, ExprSep) if _, ok := t.tablesM[name]; !ok { i := len(t.tables) + 1 @@ -79,31 +84,30 @@ func (t *dbTables) get(name string) (*dbTable, bool) { return j, ok } -// get related fields info in recursive depth loop. +// get related Fields info in recursive depth loop. // loop once, depth decreases one. -func (t *dbTables) loopDepth(depth int, prefix string, fi *fieldInfo, related []string) []string { - if depth < 0 || fi.fieldType == RelManyToMany { +func (t *dbTables) loopDepth(depth int, prefix string, fi *models.FieldInfo, related []string) []string { + if depth < 0 || fi.FieldType == RelManyToMany { return related } if prefix == "" { - prefix = fi.name + prefix = fi.Name } else { - prefix = prefix + ExprSep + fi.name + prefix = prefix + ExprSep + fi.Name } related = append(related, prefix) depth-- - for _, fi := range fi.relModelInfo.fields.fieldsRel { + for _, fi := range fi.RelModelInfo.Fields.FieldsRel { related = t.loopDepth(depth, prefix, fi, related) } return related } -// parse related fields. +// parse related Fields. func (t *dbTables) parseRelated(rels []string, depth int) { - relsNum := len(rels) related := make([]string, relsNum) copy(related, rels) @@ -115,7 +119,7 @@ func (t *dbTables) parseRelated(rels []string, depth int) { } relDepth-- - for _, fi := range t.mi.fields.fieldsRel { + for _, fi := range t.mi.Fields.FieldsRel { related = t.loopDepth(relDepth, "", fi, related) } @@ -131,18 +135,18 @@ func (t *dbTables) parseRelated(rels []string, depth int) { inner := true for _, ex := range exs { - if fi, ok := mmi.fields.GetByAny(ex); ok && fi.rel && fi.fieldType != RelManyToMany { - names = append(names, fi.name) - mmi = fi.relModelInfo + if fi, ok := mmi.Fields.GetByAny(ex); ok && fi.Rel && fi.FieldType != RelManyToMany { + names = append(names, fi.Name) + mmi = fi.RelModelInfo - if fi.null || t.skipEnd { + if fi.Null || t.skipEnd { inner = false } jt := t.set(names, mmi, fi, inner) jt.jtl = jtl - if fi.reverse { + if fi.Reverse { cancel = false } @@ -183,24 +187,24 @@ func (t *dbTables) getJoinSQL() (join string) { t1 = jt.jtl.index } t2 = jt.index - table = jt.mi.table + table = jt.mi.Table switch { - case jt.fi.fieldType == RelManyToMany || jt.fi.fieldType == RelReverseMany || jt.fi.reverse && jt.fi.reverseFieldInfo.fieldType == RelManyToMany: - c1 = jt.fi.mi.fields.pk.column - for _, ffi := range jt.mi.fields.fieldsRel { - if jt.fi.mi == ffi.relModelInfo { - c2 = ffi.column + case jt.fi.FieldType == RelManyToMany || jt.fi.FieldType == RelReverseMany || jt.fi.Reverse && jt.fi.ReverseFieldInfo.FieldType == RelManyToMany: + c1 = jt.fi.Mi.Fields.Pk.Column + for _, ffi := range jt.mi.Fields.FieldsRel { + if jt.fi.Mi == ffi.RelModelInfo { + c2 = ffi.Column break } } default: - c1 = jt.fi.column - c2 = jt.fi.relModelInfo.fields.pk.column + c1 = jt.fi.Column + c2 = jt.fi.RelModelInfo.Fields.Pk.Column - if jt.fi.reverse { - c1 = jt.mi.fields.pk.column - c2 = jt.fi.reverseFieldInfo.column + if jt.fi.Reverse { + c1 = jt.mi.Fields.Pk.Column + c2 = jt.fi.ReverseFieldInfo.Column } } @@ -211,11 +215,11 @@ func (t *dbTables) getJoinSQL() (join string) { } // parse orm model struct field tag expression. -func (t *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, name string, info *fieldInfo, success bool) { +func (t *dbTables) parseExprs(mi *models.ModelInfo, exprs []string) (index, name string, info *models.FieldInfo, success bool) { var ( jtl *dbTable - fi *fieldInfo - fiN *fieldInfo + fi *models.FieldInfo + fiN *models.FieldInfo mmi = mi ) @@ -236,38 +240,38 @@ loopFor: } if i == 0 { - fi, ok = mmi.fields.GetByAny(ex) + fi, ok = mmi.Fields.GetByAny(ex) } _ = okN if ok { - isRel := fi.rel || fi.reverse + isRel := fi.Rel || fi.Reverse - names = append(names, fi.name) + names = append(names, fi.Name) switch { - case fi.rel: - mmi = fi.relModelInfo - if fi.fieldType == RelManyToMany { - mmi = fi.relThroughModelInfo + case fi.Rel: + mmi = fi.RelModelInfo + if fi.FieldType == RelManyToMany { + mmi = fi.RelThroughModelInfo } - case fi.reverse: - mmi = fi.reverseFieldInfo.mi + case fi.Reverse: + mmi = fi.ReverseFieldInfo.Mi } if i < num { - fiN, okN = mmi.fields.GetByAny(exprs[i+1]) + fiN, okN = mmi.Fields.GetByAny(exprs[i+1]) } - if isRel && (!fi.mi.isThrough || num != i) { - if fi.null || t.skipEnd { + if isRel && (!fi.Mi.IsThrough || num != i) { + if fi.Null || t.skipEnd { inner = false } if t.skipEnd && okN || !t.skipEnd { - if t.skipEnd && okN && fiN.pk { + if t.skipEnd && okN && fiN.Pk { goto loopEnd } @@ -293,20 +297,20 @@ loopFor: info = fi if jtl == nil { - name = fi.name + name = fi.Name } else { - name = jtl.name + ExprSep + fi.name + name = jtl.name + ExprSep + fi.Name } switch { - case fi.rel: + case fi.Rel: - case fi.reverse: - switch fi.reverseFieldInfo.fieldType { + case fi.Reverse: + switch fi.ReverseFieldInfo.FieldType { case RelOneToOne, RelForeignKey: index = jtl.index - info = fi.reverseFieldInfo.mi.fields.pk - name = info.name + info = fi.ReverseFieldInfo.Mi.Fields.Pk + name = info.Name } } @@ -380,7 +384,7 @@ func (t *dbTables) getCondSQL(cond *Condition, sub bool, tz *time.Location) (whe operSQL, args = t.base.GenerateOperatorSQL(mi, fi, operator, p.args, tz) } - leftCol := fmt.Sprintf("%s.%s%s%s", index, Q, fi.column, Q) + leftCol := fmt.Sprintf("%s.%s%s%s", index, Q, fi.Column, Q) t.base.GenerateOperatorLeftCol(fi, operator, &leftCol) where += fmt.Sprintf("%s %s ", leftCol, operSQL) @@ -413,7 +417,7 @@ func (t *dbTables) getGroupSQL(groups []string) (groupSQL string) { panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(exprs, ExprSep))) } - groupSqls = append(groupSqls, fmt.Sprintf("%s.%s%s%s", index, Q, fi.column, Q)) + groupSqls = append(groupSqls, fmt.Sprintf("%s.%s%s%s", index, Q, fi.Column, Q)) } groupSQL = fmt.Sprintf("GROUP BY %s ", strings.Join(groupSqls, ", ")) @@ -421,7 +425,7 @@ func (t *dbTables) getGroupSQL(groups []string) (groupSQL string) { } // generate order sql. -func (t *dbTables) getOrderSQL(orders []string) (orderSQL string) { +func (t *dbTables) getOrderSQL(orders []*order_clause.Order) (orderSQL string) { if len(orders) == 0 { return } @@ -430,19 +434,25 @@ func (t *dbTables) getOrderSQL(orders []string) (orderSQL string) { orderSqls := make([]string, 0, len(orders)) for _, order := range orders { - asc := "ASC" - if order[0] == '-' { - asc = "DESC" - order = order[1:] - } - exprs := strings.Split(order, ExprSep) + column := order.GetColumn() + clause := strings.Split(column, clauses.ExprDot) + + if order.IsRaw() { + if len(clause) == 2 { + orderSqls = append(orderSqls, fmt.Sprintf("%s.%s %s", clause[0], clause[1], order.SortString())) + } else if len(clause) == 1 { + orderSqls = append(orderSqls, fmt.Sprintf("%s %s", clause[0], order.SortString())) + } else { + panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(clause, ExprSep))) + } + } else { + index, _, fi, suc := t.parseExprs(t.mi, clause) + if !suc { + panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(clause, ExprSep))) + } - index, _, fi, suc := t.parseExprs(t.mi, exprs) - if !suc { - panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(exprs, ExprSep))) + orderSqls = append(orderSqls, fmt.Sprintf("%s.%s%s%s %s", index, Q, fi.Column, Q, order.SortString())) } - - orderSqls = append(orderSqls, fmt.Sprintf("%s.%s%s%s %s", index, Q, fi.column, Q, asc)) } orderSQL = fmt.Sprintf("ORDER BY %s ", strings.Join(orderSqls, ", ")) @@ -450,7 +460,7 @@ func (t *dbTables) getOrderSQL(orders []string) (orderSQL string) { } // generate limit sql. -func (t *dbTables) getLimitSQL(mi *modelInfo, offset int64, limit int64) (limits string) { +func (t *dbTables) getLimitSQL(mi *models.ModelInfo, offset int64, limit int64) (limits string) { if limit == 0 { limit = int64(DefaultRowsLimit) } @@ -472,8 +482,17 @@ func (t *dbTables) getLimitSQL(mi *modelInfo, offset int64, limit int64) (limits return } +// getIndexSql generate index sql. +func (t *dbTables) getIndexSql(tableName string, useIndex int, indexes []string) (clause string) { + if len(indexes) == 0 { + return + } + + return t.base.GenerateSpecifyIndex(tableName, useIndex, indexes) +} + // crete new tables collection. -func newDbTables(mi *modelInfo, base dbBaser) *dbTables { +func newDbTables(mi *models.ModelInfo, base dbBaser) *dbTables { tables := &dbTables{} tables.tablesM = make(map[string]*dbTable) tables.mi = mi diff --git a/client/orm/db_test.go b/client/orm/db_test.go new file mode 100644 index 0000000000..630e25f3e7 --- /dev/null +++ b/client/orm/db_test.go @@ -0,0 +1,1459 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "errors" + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "github.com/beego/beego/v2/client/orm/clauses/order_clause" + "github.com/beego/beego/v2/client/orm/internal/buffers" + + "github.com/beego/beego/v2/client/orm/internal/models" +) + +func TestDbBase_InsertValueSQL(t *testing.T) { + mi := &models.ModelInfo{ + Table: "test_table", + } + + testCases := []struct { + name string + db *dbBase + isMulti bool + names []string + values []interface{} + + wantRes string + }{ + { + name: "single insert by dbBase", + db: &dbBase{ + ins: &dbBase{}, + }, + isMulti: false, + names: []string{"name", "age"}, + values: []interface{}{"test", 18}, + wantRes: "INSERT INTO `test_table` (`name`, `age`) VALUES (?, ?)", + }, + { + name: "single insert by dbBasePostgres", + db: &dbBase{ + ins: newdbBasePostgres(), + }, + isMulti: false, + names: []string{"name", "age"}, + values: []interface{}{"test", 18}, + wantRes: "INSERT INTO \"test_table\" (\"name\", \"age\") VALUES ($1, $2)", + }, + { + name: "multi insert by dbBase", + db: &dbBase{ + ins: &dbBase{}, + }, + isMulti: true, + names: []string{"name", "age"}, + values: []interface{}{"test", 18, "test2", 19}, + wantRes: "INSERT INTO `test_table` (`name`, `age`) VALUES (?, ?), (?, ?)", + }, + { + name: "multi insert by dbBasePostgres", + db: &dbBase{ + ins: newdbBasePostgres(), + }, + isMulti: true, + names: []string{"name", "age"}, + values: []interface{}{"test", 18, "test2", 19}, + wantRes: "INSERT INTO \"test_table\" (\"name\", \"age\") VALUES ($1, $2), ($3, $4)", + }, + { + name: "multi insert by dbBase but values is not enough", + db: &dbBase{ + ins: &dbBase{}, + }, + isMulti: true, + names: []string{"name", "age"}, + values: []interface{}{"test", 18, "test2"}, + wantRes: "INSERT INTO `test_table` (`name`, `age`) VALUES (?, ?)", + }, + { + name: "multi insert by dbBasePostgres but values is not enough", + db: &dbBase{ + ins: newdbBasePostgres(), + }, + isMulti: true, + names: []string{"name", "age"}, + values: []interface{}{"test", 18, "test2"}, + wantRes: "INSERT INTO \"test_table\" (\"name\", \"age\") VALUES ($1, $2)", + }, + { + name: "single insert by dbBase but values is double to names", + db: &dbBase{ + ins: &dbBase{}, + }, + isMulti: false, + names: []string{"name", "age"}, + values: []interface{}{"test", 18, "test2", 19}, + wantRes: "INSERT INTO `test_table` (`name`, `age`) VALUES (?, ?)", + }, + { + name: "single insert by dbBasePostgres but values is double to names", + db: &dbBase{ + ins: newdbBasePostgres(), + }, + isMulti: false, + names: []string{"name", "age"}, + values: []interface{}{"test", 18, "test2", 19}, + wantRes: "INSERT INTO \"test_table\" (\"name\", \"age\") VALUES ($1, $2)", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + + res := tc.db.InsertValueSQL(tc.names, tc.values, tc.isMulti, mi) + + assert.Equal(t, tc.wantRes, res) + }) + } +} + +func TestDbBase_UpdateSQL(t *testing.T) { + mi := &models.ModelInfo{ + Table: "test_table", + } + + testCases := []struct { + name string + db *dbBase + + setNames []string + pkName string + + wantRes string + }{ + { + name: "update by dbBase", + db: &dbBase{ + ins: &dbBase{}, + }, + setNames: []string{"name", "age", "sender"}, + pkName: "id", + wantRes: "UPDATE `test_table` SET `name` = ?, `age` = ?, `sender` = ? WHERE `id` = ?", + }, + { + name: "update by dbBasePostgres", + db: &dbBase{ + ins: newdbBasePostgres(), + }, + setNames: []string{"name", "age", "sender"}, + pkName: "id", + wantRes: "UPDATE \"test_table\" SET \"name\" = $1, \"age\" = $2, \"sender\" = $3 WHERE \"id\" = $4", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + + res := tc.db.UpdateSQL(tc.setNames, tc.pkName, mi) + + assert.Equal(t, tc.wantRes, res) + }) + } +} + +func TestDbBase_DeleteSQL(t *testing.T) { + mi := &models.ModelInfo{ + Table: "test_table", + } + + testCases := []struct { + name string + db *dbBase + + whereCols []string + + wantRes string + }{ + { + name: "delete by dbBase with id", + db: &dbBase{ + ins: &dbBase{}, + }, + whereCols: []string{"id"}, + wantRes: "DELETE FROM `test_table` WHERE `id` = ?", + }, + { + name: "delete by dbBase not id", + db: &dbBase{ + ins: &dbBase{}, + }, + whereCols: []string{"name", "age"}, + wantRes: "DELETE FROM `test_table` WHERE `name` = ? AND `age` = ?", + }, + { + name: "delete by dbBasePostgres with id", + db: &dbBase{ + ins: newdbBasePostgres(), + }, + whereCols: []string{"id"}, + wantRes: "DELETE FROM \"test_table\" WHERE \"id\" = $1", + }, + { + name: "delete by dbBasePostgres not id", + db: &dbBase{ + ins: newdbBasePostgres(), + }, + whereCols: []string{"name", "age"}, + wantRes: "DELETE FROM \"test_table\" WHERE \"name\" = $1 AND \"age\" = $2", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + + res := tc.db.DeleteSQL(tc.whereCols, mi) + + assert.Equal(t, tc.wantRes, res) + }) + } +} + +func TestDbBase_buildSetSQL(t *testing.T) { + testCases := []struct { + name string + + db *dbBase + + columns []string + values []interface{} + + wantRes string + wantValues []interface{} + }{ + { + name: "Set add/mul operator by dbBase", + db: &dbBase{ + ins: &dbBase{}, + }, + columns: []string{"name", "age", "score"}, + values: []interface{}{ + "test_name", + colValue{ + opt: ColAdd, + value: 12, + }, + colValue{ + opt: ColMultiply, + value: 2, + }, + "test_origin_name", + 18, + }, + wantRes: "SET T0.`name` = ?, T0.`age` = T0.`age` + ?, T0.`score` = T0.`score` * ?", + wantValues: []interface{}{"test_name", int64(12), int64(2), "test_origin_name", 18}, + }, + { + name: "Set min/except operator by dbBase", + db: &dbBase{ + ins: &dbBase{}, + }, + columns: []string{"name", "age", "score"}, + values: []interface{}{ + "test_name", + colValue{ + opt: ColMinus, + value: 12, + }, + colValue{ + opt: ColExcept, + value: 2, + }, + "test_origin_name", + 18, + }, + wantRes: "SET T0.`name` = ?, T0.`age` = T0.`age` - ?, T0.`score` = T0.`score` / ?", + wantValues: []interface{}{"test_name", int64(12), int64(2), "test_origin_name", 18}, + }, + { + name: "Set bitRShift/bitLShift operator by dbBase", + db: &dbBase{ + ins: &dbBase{}, + }, + columns: []string{"name", "age", "score"}, + values: []interface{}{ + "test_name", + colValue{ + opt: ColBitRShift, + value: 12, + }, + colValue{ + opt: ColBitLShift, + value: 2, + }, + "test_origin_name", + 18, + }, + wantRes: "SET T0.`name` = ?, T0.`age` = T0.`age` >> ?, T0.`score` = T0.`score` << ?", + wantValues: []interface{}{"test_name", int64(12), int64(2), "test_origin_name", 18}, + }, + { + name: "Set bitAnd/bitOr/bitXOR operator by dbBase", + db: &dbBase{ + ins: &dbBase{}, + }, + columns: []string{"count", "age", "score"}, + values: []interface{}{ + colValue{ + opt: ColBitAnd, + value: 28, + }, + colValue{ + opt: ColBitOr, + value: 12, + }, + colValue{ + opt: ColBitXOR, + value: 2, + }, + "test_origin_name", + 18, + }, + wantRes: "SET T0.`count` = T0.`count` & ?, T0.`age` = T0.`age` | ?, T0.`score` = T0.`score` ^ ?", + wantValues: []interface{}{int64(28), int64(12), int64(2), "test_origin_name", 18}, + }, + { + name: "Set add/mul operator by dbBasePostgres", + db: &dbBase{ + ins: newdbBasePostgres(), + }, + columns: []string{"name", "age", "score"}, + values: []interface{}{ + "test_name", + colValue{ + opt: ColAdd, + value: 12, + }, + colValue{ + opt: ColMultiply, + value: 2, + }, + "test_origin_name", + 18, + }, + wantRes: `SET "name" = ?, "age" = "age" + ?, "score" = "score" * ?`, + wantValues: []interface{}{"test_name", int64(12), int64(2), "test_origin_name", 18}, + }, + { + name: "Set min/except operator by dbBasePostgres", + db: &dbBase{ + ins: newdbBasePostgres(), + }, + columns: []string{"name", "age", "score"}, + values: []interface{}{ + "test_name", + colValue{ + opt: ColMinus, + value: 12, + }, + colValue{ + opt: ColExcept, + value: 2, + }, + "test_origin_name", + 18, + }, + wantRes: `SET "name" = ?, "age" = "age" - ?, "score" = "score" / ?`, + wantValues: []interface{}{"test_name", int64(12), int64(2), "test_origin_name", 18}, + }, + { + name: "Set bitRShift/bitLShift operator by dbBasePostgres", + db: &dbBase{ + ins: newdbBasePostgres(), + }, + columns: []string{"name", "age", "score"}, + values: []interface{}{ + "test_name", + colValue{ + opt: ColBitRShift, + value: 12, + }, + colValue{ + opt: ColBitLShift, + value: 2, + }, + "test_origin_name", + 18, + }, + wantRes: `SET "name" = ?, "age" = "age" >> ?, "score" = "score" << ?`, + wantValues: []interface{}{"test_name", int64(12), int64(2), "test_origin_name", 18}, + }, + { + name: "Set bitAnd/bitOr/bitXOR operator by dbBasePostgres", + db: &dbBase{ + ins: newdbBasePostgres(), + }, + columns: []string{"count", "age", "score"}, + values: []interface{}{ + colValue{ + opt: ColBitAnd, + value: 28, + }, + colValue{ + opt: ColBitOr, + value: 12, + }, + colValue{ + opt: ColBitXOR, + value: 2, + }, + "test_origin_name", + 18, + }, + wantRes: `SET "count" = "count" & ?, "age" = "age" | ?, "score" = "score" ^ ?`, + wantValues: []interface{}{int64(28), int64(12), int64(2), "test_origin_name", 18}, + }, + { + name: "Set add/mul operator by dbBaseSqlite", + db: &dbBase{ + ins: newdbBaseSqlite(), + }, + columns: []string{"name", "age", "score"}, + values: []interface{}{ + "test_name", + colValue{ + opt: ColAdd, + value: 12, + }, + colValue{ + opt: ColMultiply, + value: 2, + }, + "test_origin_name", + 18, + }, + wantRes: "SET `name` = ?, `age` = `age` + ?, `score` = `score` * ?", + wantValues: []interface{}{"test_name", int64(12), int64(2), "test_origin_name", 18}, + }, + { + name: "Set min/except operator by dbBaseSqlite", + db: &dbBase{ + ins: newdbBaseSqlite(), + }, + columns: []string{"name", "age", "score"}, + values: []interface{}{ + "test_name", + colValue{ + opt: ColMinus, + value: 12, + }, + colValue{ + opt: ColExcept, + value: 2, + }, + "test_origin_name", + 18, + }, + wantRes: "SET `name` = ?, `age` = `age` - ?, `score` = `score` / ?", + wantValues: []interface{}{"test_name", int64(12), int64(2), "test_origin_name", 18}, + }, + { + name: "Set bitRShift/bitLShift operator by dbBaseSqlite", + db: &dbBase{ + ins: newdbBaseSqlite(), + }, + columns: []string{"name", "age", "score"}, + values: []interface{}{ + "test_name", + colValue{ + opt: ColBitRShift, + value: 12, + }, + colValue{ + opt: ColBitLShift, + value: 2, + }, + "test_origin_name", + 18, + }, + wantRes: "SET `name` = ?, `age` = `age` >> ?, `score` = `score` << ?", + wantValues: []interface{}{"test_name", int64(12), int64(2), "test_origin_name", 18}, + }, + { + name: "Set bitAnd/bitOr/bitXOR operator by dbBaseSqlite", + db: &dbBase{ + ins: newdbBaseSqlite(), + }, + columns: []string{"count", "age", "score"}, + values: []interface{}{ + colValue{ + opt: ColBitAnd, + value: 28, + }, + colValue{ + opt: ColBitOr, + value: 12, + }, + colValue{ + opt: ColBitXOR, + value: 2, + }, + "test_origin_name", + 18, + }, + wantRes: "SET `count` = `count` & ?, `age` = `age` | ?, `score` = `score` ^ ?", + wantValues: []interface{}{int64(28), int64(12), int64(2), "test_origin_name", 18}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + + buf := buffers.Get() + defer buffers.Put(buf) + + tc.db.buildSetSQL(buf, tc.columns, tc.values) + + assert.Equal(t, tc.wantRes, buf.String()) + assert.Equal(t, tc.wantValues, tc.values) + }) + } +} + +func TestDbBase_UpdateBatchSQL(t *testing.T) { + registry := models.NewModelCacheHandler() + err := registry.Register("", false, new(testTab)) + + assert.Nil(t, err) + mi := &models.ModelInfo{ + Table: "test_tab", + Fields: &models.Fields{ + Pk: &models.FieldInfo{ + Column: "test_id", + }, + }, + } + + testCases := []struct { + name string + db *dbBase + + columns []string + values []interface{} + + specifyIndexes string + join string + where string + + wantRes string + }{ + { + name: "update batch by dbBase", + db: &dbBase{ + ins: &dbBase{}, + }, + + columns: []string{"name", "age", "score"}, + values: []interface{}{ + "test_name", + colValue{ + opt: ColAdd, + value: 12, + }, + colValue{ + opt: ColMultiply, + value: 2, + }, + "test_origin_name", + 18, + }, + + specifyIndexes: " USE INDEX(`name`) ", + join: "LEFT OUTER JOIN `test_tab_2` T1 ON T1.`id` = T0.`test_id` ", + where: "WHERE T0.`name` = ? AND T1.`age` = ?", + + wantRes: "UPDATE `test_tab` T0 USE INDEX(`name`) LEFT OUTER JOIN `test_tab_2` T1 ON T1.`id` = T0.`test_id` SET T0.`name` = ?, T0.`age` = T0.`age` + ?, T0.`score` = T0.`score` * ? WHERE T0.`name` = ? AND T1.`age` = ?", + }, + { + name: "update batch by dbBasePostgres", + db: &dbBase{ + ins: newdbBasePostgres(), + }, + + columns: []string{"name", "age", "score"}, + values: []interface{}{ + "test_name", + colValue{ + opt: ColAdd, + value: 12, + }, + colValue{ + opt: ColMultiply, + value: 2, + }, + "test_origin_name", + 18, + }, + + specifyIndexes: ` USE INDEX("name") `, + join: `LEFT OUTER JOIN "test_tab_2" T1 ON T1."id" = T0."test_id" `, + where: `WHERE T0."name" = ? AND T1."age" = ?`, + + wantRes: `UPDATE "test_tab" SET "name" = $1, "age" = "age" + $2, "score" = "score" * $3 WHERE "test_id" IN ( SELECT T0."test_id" FROM "test_tab" T0 USE INDEX("name") LEFT OUTER JOIN "test_tab_2" T1 ON T1."id" = T0."test_id" WHERE T0."name" = $4 AND T1."age" = $5 )`, + }, + { + name: "update batch by dbBaseSqlite", + db: &dbBase{ + ins: newdbBaseSqlite(), + }, + + columns: []string{"name", "age", "score"}, + values: []interface{}{ + "test_name", + colValue{ + opt: ColAdd, + value: 12, + }, + colValue{ + opt: ColMultiply, + value: 2, + }, + "test_origin_name", + 18, + }, + + specifyIndexes: " USE INDEX(`name`) ", + join: "LEFT OUTER JOIN `test_tab_2` T1 ON T1.`id` = T0.`test_id` ", + where: "WHERE T0.`name` = ? AND T1.`age` = ?", + + wantRes: "UPDATE `test_tab` SET `name` = ?, `age` = `age` + ?, `score` = `score` * ? WHERE `test_id` IN ( SELECT T0.`test_id` FROM `test_tab` T0 USE INDEX(`name`) LEFT OUTER JOIN `test_tab_2` T1 ON T1.`id` = T0.`test_id` WHERE T0.`name` = ? AND T1.`age` = ? )", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + + res := tc.db.UpdateBatchSQL(mi, tc.columns, tc.values, tc.specifyIndexes, tc.join, tc.where) + + assert.Equal(t, tc.wantRes, res) + }) + } +} + +func TestDbBase_InsertOrUpdateSQL(t *testing.T) { + registry := models.NewModelCacheHandler() + err := registry.Register("", false, new(testTab)) + + assert.Nil(t, err) + mi := &models.ModelInfo{ + Table: "test_tab", + } + + testCases := []struct { + name string + db *dbBase + + names []string + values []interface{} + a *DB + args []string + + wantRes string + wantErr error + wantValues []interface{} + }{ + { + name: "test nonsupport driver", + db: &dbBase{ + ins: newdbBaseSqlite(), + }, + + names: []string{"name", "age", "score"}, + values: []interface{}{ + "test_name", + 18, + 12, + }, + a: &DB{ + driver: DRSqlite, + driverName: "sqlite3", + }, + args: []string{ + "`age`=20", + "`score`=`score`+1", + }, + + wantErr: errors.New("`sqlite3` nonsupport InsertOrUpdate in beego"), + wantValues: []interface{}{ + "test_name", + 18, + 12, + }, + }, + { + name: "insert or update with MySQL", + db: &dbBase{ + ins: newdbBaseMysql(), + }, + + names: []string{"name", "age", "score"}, + values: []interface{}{ + "test_name", + 18, + 12, + }, + a: &DB{ + driver: DRMySQL, + driverName: "mysql", + }, + args: []string{ + "`age`=20", + "`score`=`score`+1", + }, + + wantRes: "INSERT INTO `test_tab` (`name`, `age`, `score`) VALUES (?, ?, ?) ON DUPLICATE KEY UPDATE `name`=?, `age`=20, `score`=`score`+1", + wantValues: []interface{}{ + "test_name", + 18, + 12, + "test_name", + }, + }, + { + name: "insert or update with MySQL with no args", + db: &dbBase{ + ins: newdbBaseMysql(), + }, + + names: []string{"name", "age", "score"}, + values: []interface{}{ + "test_name", + 18, + 12, + }, + a: &DB{ + driver: DRMySQL, + driverName: "mysql", + }, + + wantRes: "INSERT INTO `test_tab` (`name`, `age`, `score`) VALUES (?, ?, ?) ON DUPLICATE KEY UPDATE `name`=?, `age`=?, `score`=?", + wantValues: []interface{}{ + "test_name", + 18, + 12, + "test_name", + 18, + 12, + }, + }, + { + name: "insert or update with PostgreSQL normal", + db: &dbBase{ + ins: newdbBasePostgres(), + }, + + names: []string{"name", "age", "score"}, + values: []interface{}{ + "test_name", + 18, + 12, + }, + a: &DB{ + driver: DRPostgres, + driverName: "postgres", + }, + args: []string{ + `"name"`, + `"score"="score_1"`, + }, + + wantRes: `INSERT INTO "test_tab" ("name", "age", "score") VALUES ($1, $2, $3) ON CONFLICT ("name") DO UPDATE SET "name"=$4, "age"=$5, "score"=(select "score_1" from test_tab where "name" = $6 )`, + wantValues: []interface{}{ + "test_name", + 18, + 12, + "test_name", + 18, + "test_name", + }, + }, + { + name: "insert or update with PostgreSQL without conflict column", + db: &dbBase{ + ins: newdbBasePostgres(), + }, + + names: []string{"name", "age", "score"}, + values: []interface{}{ + "test_name", + 18, + 12, + }, + a: &DB{ + driver: DRPostgres, + driverName: "postgres", + }, + + wantErr: errors.New("`postgres` use InsertOrUpdate must have a conflict column"), + wantValues: []interface{}{ + "test_name", + 18, + 12, + }, + }, + { + name: "insert or update with PostgreSQL the conflict column is not in front of the specified column", + db: &dbBase{ + ins: newdbBasePostgres(), + }, + + names: []string{"score", "name", "age"}, + values: []interface{}{ + 12, + "test_name", + 18, + }, + a: &DB{ + driver: DRPostgres, + driverName: "postgres", + }, + args: []string{ + `"name"`, + `"score"="score_1"`, + }, + + wantErr: errors.New("`\"name\"` must be in front of `\"score\"` in your struct"), + wantValues: []interface{}{ + 12, + "test_name", + 18, + }, + }, + { + name: "insert or update with PostgreSQL the conflict column is in front of the specified column", + db: &dbBase{ + ins: newdbBasePostgres(), + }, + + names: []string{"age", "name", "score"}, + values: []interface{}{ + 18, + "test_name", + 12, + }, + a: &DB{ + driver: DRPostgres, + driverName: "postgres", + }, + args: []string{ + `"name"`, + `"score"="score_1"`, + }, + + wantRes: `INSERT INTO "test_tab" ("age", "name", "score") VALUES ($1, $2, $3) ON CONFLICT ("name") DO UPDATE SET "age"=$4, "name"=$5, "score"=(select "score_1" from test_tab where "name" = $6 )`, + wantValues: []interface{}{ + 18, + "test_name", + 12, + 18, + "test_name", + "test_name", + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + + res, err := tc.db.InsertOrUpdateSQL(tc.names, &tc.values, mi, tc.a, tc.args...) + + assert.Equal(t, tc.wantValues, tc.values) + + assert.Equal(t, tc.wantErr, err) + if err != nil { + return + } + + assert.Equal(t, tc.wantRes, res) + }) + } + +} + +func TestDbBase_readBatchSQL(t *testing.T) { + + registry := models.NewModelCacheHandler() + err := registry.Register("", false, new(testTab), new(testTab1), new(testTab2)) + + assert.Nil(t, err) + + registry.Bootstrap() + + mi, ok := registry.GetByMd(new(testTab)) + + assert.True(t, ok) + + cond := NewCondition().And("name", "test_name"). + OrCond(NewCondition().And("age__gt", 18).And("score__lt", 60)) + + tz := time.Local + + testCases := []struct { + name string + db *dbBase + + tCols []string + qs querySet + + wantRes string + wantArgs []interface{} + }{ + { + name: "read batch with MySQL", + db: &dbBase{ + ins: newdbBaseMysql(), + }, + tCols: []string{"name", "score"}, + qs: querySet{ + mi: mi, + cond: cond, + limit: 10, + offset: 100, + groups: []string{"name", "age"}, + orders: []*order_clause.Order{ + order_clause.Clause(order_clause.Column("score"), + order_clause.SortDescending()), + order_clause.Clause(order_clause.Column("age"), + order_clause.SortAscending()), + }, + useIndex: 1, + indexes: []string{"name", "score"}, + related: make([]string, 0), + relDepth: 2, + }, + wantRes: "SELECT T0.`name`, T0.`score`, T1.`id`, T1.`name_1`, T1.`age_1`, T1.`score_1`, T1.`test_tab_2_id`, T2.`id`, T2.`name_2`, T2.`age_2`, T2.`score_2` FROM `test_tab` T0 USE INDEX(`name`,`score`) INNER JOIN `test_tab1` T1 ON T1.`id` = T0.`test_tab_1_id` INNER JOIN `test_tab2` T2 ON T2.`id` = T1.`test_tab_2_id` WHERE T0.`name` = ? OR ( T0.`age` > ? AND T0.`score` < ? ) GROUP BY T0.`name`, T0.`age` ORDER BY T0.`score` DESC, T0.`age` ASC LIMIT 10 OFFSET 100", + wantArgs: []interface{}{"test_name", int64(18), int64(60)}, + }, + { + name: "read batch with MySQL and distinct", + db: &dbBase{ + ins: newdbBaseMysql(), + }, + tCols: []string{"name", "score"}, + qs: querySet{ + mi: mi, + cond: cond, + limit: 10, + offset: 100, + groups: []string{"name", "age"}, + orders: []*order_clause.Order{ + order_clause.Clause(order_clause.Column("score"), + order_clause.SortDescending()), + order_clause.Clause(order_clause.Column("age"), + order_clause.SortAscending()), + }, + useIndex: 1, + indexes: []string{"name", "score"}, + distinct: true, + related: make([]string, 0), + relDepth: 2, + }, + wantRes: "SELECT DISTINCT T0.`name`, T0.`score`, T1.`id`, T1.`name_1`, T1.`age_1`, T1.`score_1`, T1.`test_tab_2_id`, T2.`id`, T2.`name_2`, T2.`age_2`, T2.`score_2` FROM `test_tab` T0 USE INDEX(`name`,`score`) INNER JOIN `test_tab1` T1 ON T1.`id` = T0.`test_tab_1_id` INNER JOIN `test_tab2` T2 ON T2.`id` = T1.`test_tab_2_id` WHERE T0.`name` = ? OR ( T0.`age` > ? AND T0.`score` < ? ) GROUP BY T0.`name`, T0.`age` ORDER BY T0.`score` DESC, T0.`age` ASC LIMIT 10 OFFSET 100", + wantArgs: []interface{}{"test_name", int64(18), int64(60)}, + }, + { + name: "read batch with MySQL and aggregate", + db: &dbBase{ + ins: newdbBaseMysql(), + }, + tCols: []string{"name", "score"}, + qs: querySet{ + mi: mi, + cond: cond, + limit: 10, + offset: 100, + groups: []string{"name", "age"}, + orders: []*order_clause.Order{ + order_clause.Clause(order_clause.Column("score"), + order_clause.SortDescending()), + order_clause.Clause(order_clause.Column("age"), + order_clause.SortAscending()), + }, + useIndex: 1, + indexes: []string{"name", "score"}, + aggregate: "sum(`T0`.`score`), count(`T1`.`name_1`)", + related: make([]string, 0), + relDepth: 2, + }, + wantRes: "SELECT sum(`T0`.`score`), count(`T1`.`name_1`) FROM `test_tab` T0 USE INDEX(`name`,`score`) INNER JOIN `test_tab1` T1 ON T1.`id` = T0.`test_tab_1_id` INNER JOIN `test_tab2` T2 ON T2.`id` = T1.`test_tab_2_id` WHERE T0.`name` = ? OR ( T0.`age` > ? AND T0.`score` < ? ) GROUP BY T0.`name`, T0.`age` ORDER BY T0.`score` DESC, T0.`age` ASC LIMIT 10 OFFSET 100", + wantArgs: []interface{}{"test_name", int64(18), int64(60)}, + }, + { + name: "read batch with MySQL and distinct and aggregate", + db: &dbBase{ + ins: newdbBaseMysql(), + }, + tCols: []string{"name", "score"}, + qs: querySet{ + mi: mi, + cond: cond, + limit: 10, + offset: 100, + groups: []string{"name", "age"}, + orders: []*order_clause.Order{ + order_clause.Clause(order_clause.Column("score"), + order_clause.SortDescending()), + order_clause.Clause(order_clause.Column("age"), + order_clause.SortAscending()), + }, + useIndex: 1, + indexes: []string{"name", "score"}, + distinct: true, + aggregate: "sum(`T0`.`score`), count(`T1`.`name_1`)", + related: make([]string, 0), + relDepth: 2, + }, + wantRes: "SELECT DISTINCT sum(`T0`.`score`), count(`T1`.`name_1`) FROM `test_tab` T0 USE INDEX(`name`,`score`) INNER JOIN `test_tab1` T1 ON T1.`id` = T0.`test_tab_1_id` INNER JOIN `test_tab2` T2 ON T2.`id` = T1.`test_tab_2_id` WHERE T0.`name` = ? OR ( T0.`age` > ? AND T0.`score` < ? ) GROUP BY T0.`name`, T0.`age` ORDER BY T0.`score` DESC, T0.`age` ASC LIMIT 10 OFFSET 100", + wantArgs: []interface{}{"test_name", int64(18), int64(60)}, + }, + { + name: "read batch with MySQL and for update", + db: &dbBase{ + ins: newdbBaseMysql(), + }, + tCols: []string{"name", "score"}, + qs: querySet{ + mi: mi, + cond: cond, + limit: 10, + offset: 100, + groups: []string{"name", "age"}, + orders: []*order_clause.Order{ + order_clause.Clause(order_clause.Column("score"), + order_clause.SortDescending()), + order_clause.Clause(order_clause.Column("age"), + order_clause.SortAscending()), + }, + useIndex: 1, + indexes: []string{"name", "score"}, + forUpdate: true, + related: make([]string, 0), + relDepth: 2, + }, + wantRes: "SELECT T0.`name`, T0.`score`, T1.`id`, T1.`name_1`, T1.`age_1`, T1.`score_1`, T1.`test_tab_2_id`, T2.`id`, T2.`name_2`, T2.`age_2`, T2.`score_2` FROM `test_tab` T0 USE INDEX(`name`,`score`) INNER JOIN `test_tab1` T1 ON T1.`id` = T0.`test_tab_1_id` INNER JOIN `test_tab2` T2 ON T2.`id` = T1.`test_tab_2_id` WHERE T0.`name` = ? OR ( T0.`age` > ? AND T0.`score` < ? ) GROUP BY T0.`name`, T0.`age` ORDER BY T0.`score` DESC, T0.`age` ASC LIMIT 10 OFFSET 100 FOR UPDATE", + wantArgs: []interface{}{"test_name", int64(18), int64(60)}, + }, + { + name: "read batch with PostgreSQL", + db: &dbBase{ + ins: newdbBasePostgres(), + }, + tCols: []string{"name", "score"}, + qs: querySet{ + mi: mi, + cond: cond, + limit: 10, + offset: 100, + groups: []string{"name", "age"}, + orders: []*order_clause.Order{ + order_clause.Clause(order_clause.Column("score"), + order_clause.SortDescending()), + order_clause.Clause(order_clause.Column("age"), + order_clause.SortAscending()), + }, + related: make([]string, 0), + relDepth: 2, + }, + wantRes: `SELECT T0."name", T0."score", T1."id", T1."name_1", T1."age_1", T1."score_1", T1."test_tab_2_id", T2."id", T2."name_2", T2."age_2", T2."score_2" FROM "test_tab" T0 INNER JOIN "test_tab1" T1 ON T1."id" = T0."test_tab_1_id" INNER JOIN "test_tab2" T2 ON T2."id" = T1."test_tab_2_id" WHERE T0."name" = $1 OR ( T0."age" > $2 AND T0."score" < $3 ) GROUP BY T0."name", T0."age" ORDER BY T0."score" DESC, T0."age" ASC LIMIT 10 OFFSET 100`, + wantArgs: []interface{}{"test_name", int64(18), int64(60)}, + }, + { + name: "read batch with PostgreSQL and distinct", + db: &dbBase{ + ins: newdbBasePostgres(), + }, + tCols: []string{"name", "score"}, + qs: querySet{ + mi: mi, + cond: cond, + limit: 10, + offset: 100, + groups: []string{"name", "age"}, + orders: []*order_clause.Order{ + order_clause.Clause(order_clause.Column("score"), + order_clause.SortDescending()), + order_clause.Clause(order_clause.Column("age"), + order_clause.SortAscending()), + }, + distinct: true, + related: make([]string, 0), + relDepth: 2, + }, + wantRes: `SELECT DISTINCT T0."name", T0."score", T1."id", T1."name_1", T1."age_1", T1."score_1", T1."test_tab_2_id", T2."id", T2."name_2", T2."age_2", T2."score_2" FROM "test_tab" T0 INNER JOIN "test_tab1" T1 ON T1."id" = T0."test_tab_1_id" INNER JOIN "test_tab2" T2 ON T2."id" = T1."test_tab_2_id" WHERE T0."name" = $1 OR ( T0."age" > $2 AND T0."score" < $3 ) GROUP BY T0."name", T0."age" ORDER BY T0."score" DESC, T0."age" ASC LIMIT 10 OFFSET 100`, + wantArgs: []interface{}{"test_name", int64(18), int64(60)}, + }, + { + name: "read batch with PostgreSQL and aggregate", + db: &dbBase{ + ins: newdbBasePostgres(), + }, + tCols: []string{"name", "score"}, + qs: querySet{ + mi: mi, + cond: cond, + limit: 10, + offset: 100, + groups: []string{"name", "age"}, + orders: []*order_clause.Order{ + order_clause.Clause(order_clause.Column("score"), + order_clause.SortDescending()), + order_clause.Clause(order_clause.Column("age"), + order_clause.SortAscending()), + }, + aggregate: `sum("T0"."score"), count("T1"."name_1")`, + related: make([]string, 0), + relDepth: 2, + }, + wantRes: `SELECT sum("T0"."score"), count("T1"."name_1") FROM "test_tab" T0 INNER JOIN "test_tab1" T1 ON T1."id" = T0."test_tab_1_id" INNER JOIN "test_tab2" T2 ON T2."id" = T1."test_tab_2_id" WHERE T0."name" = $1 OR ( T0."age" > $2 AND T0."score" < $3 ) GROUP BY T0."name", T0."age" ORDER BY T0."score" DESC, T0."age" ASC LIMIT 10 OFFSET 100`, + wantArgs: []interface{}{"test_name", int64(18), int64(60)}, + }, + { + name: "read batch with PostgreSQL and distinct and aggregate", + db: &dbBase{ + ins: newdbBasePostgres(), + }, + tCols: []string{"name", "score"}, + qs: querySet{ + mi: mi, + cond: cond, + limit: 10, + offset: 100, + groups: []string{"name", "age"}, + orders: []*order_clause.Order{ + order_clause.Clause(order_clause.Column("score"), + order_clause.SortDescending()), + order_clause.Clause(order_clause.Column("age"), + order_clause.SortAscending()), + }, + distinct: true, + aggregate: `sum("T0"."score"), count("T1"."name_1")`, + related: make([]string, 0), + relDepth: 2, + }, + wantRes: `SELECT DISTINCT sum("T0"."score"), count("T1"."name_1") FROM "test_tab" T0 INNER JOIN "test_tab1" T1 ON T1."id" = T0."test_tab_1_id" INNER JOIN "test_tab2" T2 ON T2."id" = T1."test_tab_2_id" WHERE T0."name" = $1 OR ( T0."age" > $2 AND T0."score" < $3 ) GROUP BY T0."name", T0."age" ORDER BY T0."score" DESC, T0."age" ASC LIMIT 10 OFFSET 100`, + wantArgs: []interface{}{"test_name", int64(18), int64(60)}, + }, + { + name: "read batch with PostgreSQL and for update", + db: &dbBase{ + ins: newdbBasePostgres(), + }, + tCols: []string{"name", "score"}, + qs: querySet{ + mi: mi, + cond: cond, + limit: 10, + offset: 100, + groups: []string{"name", "age"}, + orders: []*order_clause.Order{ + order_clause.Clause(order_clause.Column("score"), + order_clause.SortDescending()), + order_clause.Clause(order_clause.Column("age"), + order_clause.SortAscending()), + }, + forUpdate: true, + related: make([]string, 0), + relDepth: 2, + }, + wantRes: `SELECT T0."name", T0."score", T1."id", T1."name_1", T1."age_1", T1."score_1", T1."test_tab_2_id", T2."id", T2."name_2", T2."age_2", T2."score_2" FROM "test_tab" T0 INNER JOIN "test_tab1" T1 ON T1."id" = T0."test_tab_1_id" INNER JOIN "test_tab2" T2 ON T2."id" = T1."test_tab_2_id" WHERE T0."name" = $1 OR ( T0."age" > $2 AND T0."score" < $3 ) GROUP BY T0."name", T0."age" ORDER BY T0."score" DESC, T0."age" ASC LIMIT 10 OFFSET 100 FOR UPDATE`, + wantArgs: []interface{}{"test_name", int64(18), int64(60)}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + tables := newDbTables(mi, tc.db.ins) + tables.parseRelated(tc.qs.related, tc.qs.relDepth) + + res, args := tc.db.readBatchSQL(tables, tc.tCols, cond, tc.qs, mi, tz) + + assert.Equal(t, tc.wantRes, res) + assert.Equal(t, tc.wantArgs, args) + }) + } + +} + +func TestDbBase_readValuesSQL(t *testing.T) { + registry := models.NewModelCacheHandler() + err := registry.Register("", false, new(testTab), new(testTab1), new(testTab2)) + + assert.Nil(t, err) + + registry.Bootstrap() + + mi, ok := registry.GetByMd(new(testTab)) + + assert.True(t, ok) + + cond := NewCondition().And("name", "test_name"). + OrCond(NewCondition().And("age__gt", 18).And("score__lt", 60)) + + tz := time.Local + + testCases := []struct { + name string + db *dbBase + + cols []string + qs querySet + + wantRes string + wantArgs []interface{} + }{ + { + name: "read values with MySQL", + db: &dbBase{ + ins: newdbBaseMysql(), + }, + cols: []string{"T0.`name` name", "T0.`age` age", "T0.`score` score"}, + qs: querySet{ + mi: mi, + cond: cond, + limit: 10, + offset: 100, + groups: []string{"name", "age"}, + orders: []*order_clause.Order{ + order_clause.Clause(order_clause.Column("score"), + order_clause.SortDescending()), + order_clause.Clause(order_clause.Column("age"), + order_clause.SortAscending()), + }, + useIndex: 1, + indexes: []string{"name", "score"}, + }, + wantRes: "SELECT T0.`name` name, T0.`age` age, T0.`score` score FROM `test_tab` T0 USE INDEX(`name`,`score`) WHERE T0.`name` = ? OR ( T0.`age` > ? AND T0.`score` < ? ) GROUP BY T0.`name`, T0.`age` ORDER BY T0.`score` DESC, T0.`age` ASC LIMIT 10 OFFSET 100", + wantArgs: []interface{}{"test_name", int64(18), int64(60)}, + }, + { + name: "read values with MySQL and distinct", + db: &dbBase{ + ins: newdbBaseMysql(), + }, + cols: []string{"T0.`name` name", "T0.`age` age", "T0.`score` score"}, + qs: querySet{ + mi: mi, + cond: cond, + limit: 10, + offset: 100, + groups: []string{"name", "age"}, + orders: []*order_clause.Order{ + order_clause.Clause(order_clause.Column("score"), + order_clause.SortDescending()), + order_clause.Clause(order_clause.Column("age"), + order_clause.SortAscending()), + }, + useIndex: 1, + indexes: []string{"name", "score"}, + distinct: true, + }, + wantRes: "SELECT DISTINCT T0.`name` name, T0.`age` age, T0.`score` score FROM `test_tab` T0 USE INDEX(`name`,`score`) WHERE T0.`name` = ? OR ( T0.`age` > ? AND T0.`score` < ? ) GROUP BY T0.`name`, T0.`age` ORDER BY T0.`score` DESC, T0.`age` ASC LIMIT 10 OFFSET 100", + wantArgs: []interface{}{"test_name", int64(18), int64(60)}, + }, + { + name: "read values with PostgreSQL", + db: &dbBase{ + ins: newdbBasePostgres(), + }, + cols: []string{`T0."name" name`, `T0."age" age`, `T0."score" score`}, + qs: querySet{ + mi: mi, + cond: cond, + limit: 10, + offset: 100, + groups: []string{"name", "age"}, + orders: []*order_clause.Order{ + order_clause.Clause(order_clause.Column("score"), + order_clause.SortDescending()), + order_clause.Clause(order_clause.Column("age"), + order_clause.SortAscending()), + }, + }, + wantRes: `SELECT T0."name" name, T0."age" age, T0."score" score FROM "test_tab" T0 WHERE T0."name" = $1 OR ( T0."age" > $2 AND T0."score" < $3 ) GROUP BY T0."name", T0."age" ORDER BY T0."score" DESC, T0."age" ASC LIMIT 10 OFFSET 100`, + wantArgs: []interface{}{"test_name", int64(18), int64(60)}, + }, + { + name: "read values with PostgreSQL and distinct", + db: &dbBase{ + ins: newdbBasePostgres(), + }, + cols: []string{`T0."name" name`, `T0."age" age`, `T0."score" score`}, + qs: querySet{ + mi: mi, + cond: cond, + limit: 10, + offset: 100, + groups: []string{"name", "age"}, + orders: []*order_clause.Order{ + order_clause.Clause(order_clause.Column("score"), + order_clause.SortDescending()), + order_clause.Clause(order_clause.Column("age"), + order_clause.SortAscending()), + }, + distinct: true, + }, + wantRes: `SELECT DISTINCT T0."name" name, T0."age" age, T0."score" score FROM "test_tab" T0 WHERE T0."name" = $1 OR ( T0."age" > $2 AND T0."score" < $3 ) GROUP BY T0."name", T0."age" ORDER BY T0."score" DESC, T0."age" ASC LIMIT 10 OFFSET 100`, + wantArgs: []interface{}{"test_name", int64(18), int64(60)}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + tables := newDbTables(mi, tc.db.ins) + + res, args := tc.db.readValuesSQL(tables, tc.cols, tc.qs, mi, cond, tz) + + assert.Equal(t, tc.wantRes, res) + assert.Equal(t, tc.wantArgs, args) + }) + } + +} + +func TestDbBase_countSQL(t *testing.T) { + registry := models.NewModelCacheHandler() + err := registry.Register("", false, new(testTab), new(testTab1), new(testTab2)) + + assert.Nil(t, err) + + registry.Bootstrap() + + mi, ok := registry.GetByMd(new(testTab)) + + assert.True(t, ok) + + cond := NewCondition().And("name", "test_name"). + OrCond(NewCondition().And("age__gt", 18).And("score__lt", 60)) + + tz := time.Local + + testCases := []struct { + name string + db *dbBase + + qs querySet + + wantRes string + wantArgs []interface{} + }{ + { + name: "count with MySQL has no group by", + db: &dbBase{ + ins: newdbBaseMysql(), + }, + qs: querySet{ + mi: mi, + cond: cond, + useIndex: 1, + indexes: []string{"name", "score"}, + related: make([]string, 0), + relDepth: 2, + }, + wantRes: "SELECT COUNT(*) FROM `test_tab` T0 USE INDEX(`name`,`score`) INNER JOIN `test_tab1` T1 ON T1.`id` = T0.`test_tab_1_id` INNER JOIN `test_tab2` T2 ON T2.`id` = T1.`test_tab_2_id` WHERE T0.`name` = ? OR ( T0.`age` > ? AND T0.`score` < ? ) ", + wantArgs: []interface{}{"test_name", int64(18), int64(60)}, + }, + { + name: "count with MySQL has group by", + db: &dbBase{ + ins: newdbBaseMysql(), + }, + qs: querySet{ + mi: mi, + cond: cond, + useIndex: 1, + indexes: []string{"name", "score"}, + related: make([]string, 0), + relDepth: 2, + groups: []string{"name", "age"}, + }, + wantRes: "SELECT COUNT(*) FROM (SELECT COUNT(*) FROM `test_tab` T0 USE INDEX(`name`,`score`) INNER JOIN `test_tab1` T1 ON T1.`id` = T0.`test_tab_1_id` INNER JOIN `test_tab2` T2 ON T2.`id` = T1.`test_tab_2_id` WHERE T0.`name` = ? OR ( T0.`age` > ? AND T0.`score` < ? ) GROUP BY T0.`name`, T0.`age` ) AS T", + wantArgs: []interface{}{"test_name", int64(18), int64(60)}, + }, + { + name: "count with PostgreSQL has no group by", + db: &dbBase{ + ins: newdbBasePostgres(), + }, + qs: querySet{ + mi: mi, + cond: cond, + related: make([]string, 0), + relDepth: 2, + }, + wantRes: `SELECT COUNT(*) FROM "test_tab" T0 INNER JOIN "test_tab1" T1 ON T1."id" = T0."test_tab_1_id" INNER JOIN "test_tab2" T2 ON T2."id" = T1."test_tab_2_id" WHERE T0."name" = $1 OR ( T0."age" > $2 AND T0."score" < $3 ) `, + wantArgs: []interface{}{"test_name", int64(18), int64(60)}, + }, + { + name: "count with PostgreSQL has group by", + db: &dbBase{ + ins: newdbBasePostgres(), + }, + qs: querySet{ + mi: mi, + cond: cond, + related: make([]string, 0), + relDepth: 2, + groups: []string{"name", "age"}, + }, + wantRes: `SELECT COUNT(*) FROM (SELECT COUNT(*) FROM "test_tab" T0 INNER JOIN "test_tab1" T1 ON T1."id" = T0."test_tab_1_id" INNER JOIN "test_tab2" T2 ON T2."id" = T1."test_tab_2_id" WHERE T0."name" = $1 OR ( T0."age" > $2 AND T0."score" < $3 ) GROUP BY T0."name", T0."age" ) AS T`, + wantArgs: []interface{}{"test_name", int64(18), int64(60)}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + res, args := tc.db.countSQL(tc.qs, mi, cond, tz) + + assert.Equal(t, tc.wantRes, res) + assert.Equal(t, tc.wantArgs, args) + }) + } +} + +type testTab struct { + ID int64 `orm:"auto;pk;column(id)"` + Name string `orm:"column(name)"` + Age int64 `orm:"column(age)"` + Score int64 `orm:"column(score)"` + TestTab1 *testTab1 `orm:"rel(fk);column(test_tab_1_id)"` +} + +type testTab1 struct { + ID int64 `orm:"auto;pk;column(id)"` + Name1 string `orm:"column(name_1)"` + Age1 int64 `orm:"column(age_1)"` + Score1 int64 `orm:"column(score_1)"` + TestTab2 *testTab2 `orm:"rel(fk);column(test_tab_2_id)"` +} + +type testTab2 struct { + ID int64 `orm:"auto;pk;column(id)"` + Name2 int64 `orm:"column(name_2)"` + Age2 int64 `orm:"column(age_2)"` + Score2 int64 `orm:"column(score_2)"` +} diff --git a/orm/db_tidb.go b/client/orm/db_tidb.go similarity index 82% rename from orm/db_tidb.go rename to client/orm/db_tidb.go index 6020a488f5..863cf05ac1 100644 --- a/orm/db_tidb.go +++ b/client/orm/db_tidb.go @@ -15,6 +15,7 @@ package orm import ( + "context" "fmt" ) @@ -25,12 +26,12 @@ type dbBaseTidb struct { var _ dbBaser = new(dbBaseTidb) -// get mysql operator. +// Get mysql operator. func (d *dbBaseTidb) OperatorSQL(operator string) string { return mysqlOperators[operator] } -// get mysql table field types. +// Get mysql table field types. func (d *dbBaseTidb) DbTypes() map[string]string { return mysqlTypes } @@ -40,15 +41,15 @@ func (d *dbBaseTidb) ShowTablesQuery() string { return "SELECT table_name FROM information_schema.tables WHERE table_type = 'BASE TABLE' AND table_schema = DATABASE()" } -// show columns sql of table for mysql. +// show Columns sql of table for mysql. func (d *dbBaseTidb) ShowColumnsQuery(table string) string { - return fmt.Sprintf("SELECT COLUMN_NAME, COLUMN_TYPE, IS_NULLABLE FROM information_schema.columns "+ + return fmt.Sprintf("SELECT COLUMN_NAME, COLUMN_TYPE, IS_NULLABLE FROM information_schema.Columns "+ "WHERE table_schema = DATABASE() AND table_name = '%s'", table) } // execute sql to check index exist. -func (d *dbBaseTidb) IndexExists(db dbQuerier, table string, name string) bool { - row := db.QueryRow("SELECT count(*) FROM information_schema.statistics "+ +func (d *dbBaseTidb) IndexExists(ctx context.Context, db dbQuerier, table string, name string) bool { + row := db.QueryRowContext(ctx, "SELECT count(*) FROM information_schema.statistics "+ "WHERE table_schema = DATABASE() AND table_name = ? AND index_name = ?", table, name) var cnt int row.Scan(&cnt) diff --git a/orm/db_utils.go b/client/orm/db_utils.go similarity index 61% rename from orm/db_utils.go rename to client/orm/db_utils.go index 7ae10ca5e4..ae7c6fe625 100644 --- a/orm/db_utils.go +++ b/client/orm/db_utils.go @@ -18,53 +18,55 @@ import ( "fmt" "reflect" "time" + + "github.com/beego/beego/v2/client/orm/internal/utils" + + "github.com/beego/beego/v2/client/orm/internal/models" ) -// get table alias. -func getDbAlias(name string) *alias { +// Get table alias. +func getDB(name string) *DB { if al, ok := dataBaseCache.get(name); ok { return al } panic(fmt.Errorf("unknown DataBase alias name %s", name)) } -// get pk column info. -func getExistPk(mi *modelInfo, ind reflect.Value) (column string, value interface{}, exist bool) { - fi := mi.fields.pk +// Get pk column info. +func getExistPk(mi *models.ModelInfo, ind reflect.Value) (column string, value interface{}, exist bool) { + fi := mi.Fields.Pk - v := ind.FieldByIndex(fi.fieldIndex) - if fi.fieldType&IsPositiveIntegerField > 0 { + v := ind.FieldByIndex(fi.FieldIndex) + if fi.FieldType&IsPositiveIntegerField > 0 { vu := v.Uint() exist = vu > 0 value = vu - } else if fi.fieldType&IsIntegerField > 0 { + } else if fi.FieldType&IsIntegerField > 0 { vu := v.Int() exist = true value = vu - } else if fi.fieldType&IsRelField > 0 { - _, value, exist = getExistPk(fi.relModelInfo, reflect.Indirect(v)) + } else if fi.FieldType&IsRelField > 0 { + _, value, exist = getExistPk(fi.RelModelInfo, reflect.Indirect(v)) } else { vu := v.String() exist = vu != "" value = vu } - column = fi.column + column = fi.Column return } -// get fields description as flatted string. -func getFlatParams(fi *fieldInfo, args []interface{}, tz *time.Location) (params []interface{}) { - +// Get Fields description as flatted string. +func getFlatParams(fi *models.FieldInfo, args []interface{}, tz *time.Location) (params []interface{}) { outFor: for _, arg := range args { - val := reflect.ValueOf(arg) - if arg == nil { params = append(params, arg) continue } + val := reflect.ValueOf(arg) kind := val.Kind() if kind == reflect.Ptr { val = val.Elem() @@ -76,32 +78,32 @@ outFor: case reflect.String: v := val.String() if fi != nil { - if fi.fieldType == TypeTimeField || fi.fieldType == TypeDateField || fi.fieldType == TypeDateTimeField { + if fi.FieldType == TypeTimeField || fi.FieldType == TypeDateField || fi.FieldType == TypeDateTimeField { var t time.Time var err error if len(v) >= 19 { s := v[:19] - t, err = time.ParseInLocation(formatDateTime, s, DefaultTimeLoc) + t, err = time.ParseInLocation(utils.FormatDateTime, s, DefaultTimeLoc) } else if len(v) >= 10 { s := v if len(v) > 10 { s = v[:10] } - t, err = time.ParseInLocation(formatDate, s, tz) + t, err = time.ParseInLocation(utils.FormatDate, s, tz) } else { s := v if len(s) > 8 { s = v[:8] } - t, err = time.ParseInLocation(formatTime, s, tz) + t, err = time.ParseInLocation(utils.FormatTime, s, tz) } if err == nil { - if fi.fieldType == TypeDateField { - v = t.In(tz).Format(formatDate) - } else if fi.fieldType == TypeDateTimeField { - v = t.In(tz).Format(formatDateTime) + if fi.FieldType == TypeDateField { + v = t.In(tz).Format(utils.FormatDate) + } else if fi.FieldType == TypeDateTimeField { + v = t.In(tz).Format(utils.FormatDateTime) } else { - v = t.In(tz).Format(formatTime) + v = t.In(tz).Format(utils.FormatTime) } } } @@ -112,7 +114,7 @@ outFor: case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: arg = val.Uint() case reflect.Float32: - arg, _ = StrTo(ToStr(arg)).Float64() + arg, _ = utils.StrTo(utils.ToStr(arg)).Float64() case reflect.Float64: arg = val.Float() case reflect.Bool: @@ -145,20 +147,20 @@ outFor: continue outFor case reflect.Struct: if v, ok := arg.(time.Time); ok { - if fi != nil && fi.fieldType == TypeDateField { - arg = v.In(tz).Format(formatDate) - } else if fi != nil && fi.fieldType == TypeDateTimeField { - arg = v.In(tz).Format(formatDateTime) - } else if fi != nil && fi.fieldType == TypeTimeField { - arg = v.In(tz).Format(formatTime) + if fi != nil && fi.FieldType == TypeDateField { + arg = v.In(tz).Format(utils.FormatDate) + } else if fi != nil && fi.FieldType == TypeDateTimeField { + arg = v.In(tz).Format(utils.FormatDateTime) + } else if fi != nil && fi.FieldType == TypeTimeField { + arg = v.In(tz).Format(utils.FormatTime) } else { - arg = v.In(tz).Format(formatDateTime) + arg = v.In(tz).Format(utils.FormatDateTime) } } else { typ := val.Type() - name := getFullName(typ) + name := models.GetFullName(typ) var value interface{} - if mmi, ok := modelCache.getByFullName(name); ok { + if mmi, ok := models.DefaultModelCache.GetByFullName(name); ok { if _, vu, exist := getExistPk(mmi, val); exist { value = vu } diff --git a/client/orm/ddl.go b/client/orm/ddl.go new file mode 100644 index 0000000000..3cbe69391e --- /dev/null +++ b/client/orm/ddl.go @@ -0,0 +1,195 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "errors" + "fmt" + "strings" + + imodels "github.com/beego/beego/v2/client/orm/internal/models" +) + +// getDbDropSQL Get database scheme drop sql queries +func getDbDropSQL(registry *imodels.ModelCache, al *DB) (queries []string, err error) { + if registry.Empty() { + err = errors.New("no Model found, need Register your model") + return + } + + Q := al.dbBaser.TableQuote() + + for _, mi := range registry.AllOrdered() { + queries = append(queries, fmt.Sprintf(`DROP TABLE IF EXISTS %s%s%s`, Q, mi.Table, Q)) + } + return queries, nil +} + +// getDbCreateSQL Get database scheme creation sql queries +func getDbCreateSQL(registry *imodels.ModelCache, al *DB) (queries []string, tableIndexes map[string][]dbIndex, err error) { + if registry.Empty() { + err = errors.New("no Model found, need Register your model") + return + } + + Q := al.dbBaser.TableQuote() + T := al.dbBaser.DbTypes() + sep := fmt.Sprintf("%s, %s", Q, Q) + + tableIndexes = make(map[string][]dbIndex) + + for _, mi := range registry.AllOrdered() { + sql := fmt.Sprintf("-- %s\n", strings.Repeat("-", 50)) + sql += fmt.Sprintf("-- Table Structure for `%s`\n", mi.FullName) + sql += fmt.Sprintf("-- %s\n", strings.Repeat("-", 50)) + + sql += fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s%s%s (\n", Q, mi.Table, Q) + + columns := make([]string, 0, len(mi.Fields.FieldsDB)) + + sqlIndexes := [][]string{} + var commentIndexes []int // store comment indexes for postgres + + for i, fi := range mi.Fields.FieldsDB { + column := fmt.Sprintf(" %s%s%s ", Q, fi.Column, Q) + col := getColumnTyp(al, fi) + if fi.DBType != "" { + column += fi.DBType + } else if fi.Auto { + switch al.driver { + case DRSqlite, DRPostgres: + column += T["auto"] + default: + column += col + " " + T["auto"] + } + } else if fi.Pk { + column += col + " " + T["pk"] + } else { + column += col + + if !fi.Null { + column += " " + "NOT NULL" + } + + // if fi.initial.String() != "" { + // column += " DEFAULT " + fi.initial.String() + // } + + // Append attribute DEFAULT + column += getColumnDefault(fi) + + if fi.Unique { + column += " " + "UNIQUE" + } + + if fi.Index { + sqlIndexes = append(sqlIndexes, []string{fi.Column}) + } + } + + if strings.Contains(column, "%COL%") { + column = strings.Replace(column, "%COL%", fi.Column, -1) + } + + if fi.Description != "" && al.driver != DRSqlite { + if al.driver == DRPostgres { + commentIndexes = append(commentIndexes, i) + } else { + column += " " + fmt.Sprintf("COMMENT '%s'", fi.Description) + } + } + + columns = append(columns, column) + } + + if mi.Model != nil { + allnames := imodels.GetTableUnique(mi.AddrField) + if !mi.Manual && len(mi.Uniques) > 0 { + allnames = append(allnames, mi.Uniques) + } + for _, names := range allnames { + cols := make([]string, 0, len(names)) + for _, name := range names { + if fi, ok := mi.Fields.GetByAny(name); ok && fi.DBcol { + cols = append(cols, fi.Column) + } else { + panic(fmt.Errorf("cannot found column `%s` when parse UNIQUE in `%s.TableUnique`", name, mi.FullName)) + } + } + column := fmt.Sprintf(" UNIQUE (%s%s%s)", Q, strings.Join(cols, sep), Q) + columns = append(columns, column) + } + } + + sql += strings.Join(columns, ",\n") + sql += "\n)" + + if al.driver == DRMySQL { + var engine string + if mi.Model != nil { + engine = imodels.GetTableEngine(mi.AddrField) + } + if engine == "" { + engine = al.engine + } + sql += " ENGINE=" + engine + } + + sql += ";" + if al.driver == DRPostgres && len(commentIndexes) > 0 { + // append comments for postgres only + for _, index := range commentIndexes { + sql += fmt.Sprintf("\nCOMMENT ON COLUMN %s%s%s.%s%s%s is '%s';", + Q, + mi.Table, + Q, + Q, + mi.Fields.FieldsDB[index].Column, + Q, + mi.Fields.FieldsDB[index].Description) + } + } + queries = append(queries, sql) + + if mi.Model != nil { + for _, names := range imodels.GetTableIndex(mi.AddrField) { + cols := make([]string, 0, len(names)) + for _, name := range names { + if fi, ok := mi.Fields.GetByAny(name); ok && fi.DBcol { + cols = append(cols, fi.Column) + } else { + panic(fmt.Errorf("cannot found column `%s` when parse INDEX in `%s.TableIndex`", name, mi.FullName)) + } + } + sqlIndexes = append(sqlIndexes, cols) + } + } + + for _, names := range sqlIndexes { + name := mi.Table + "_" + strings.Join(names, "_") + cols := strings.Join(names, sep) + sql := fmt.Sprintf("CREATE INDEX %s%s%s ON %s%s%s (%s%s%s);", Q, name, Q, Q, mi.Table, Q, Q, cols, Q) + + index := dbIndex{} + index.Table = mi.Table + index.Name = name + index.SQL = sql + + tableIndexes[mi.Table] = append(tableIndexes[mi.Table], index) + } + + } + return +} diff --git a/client/orm/ddl_test.go b/client/orm/ddl_test.go new file mode 100644 index 0000000000..95c018b648 --- /dev/null +++ b/client/orm/ddl_test.go @@ -0,0 +1,88 @@ +// Copyright 2022 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "testing" + + "github.com/beego/beego/v2/client/orm/internal/models" + + "github.com/stretchr/testify/assert" +) + +type ModelWithComments struct { + ID int `orm:"column(id);description(user id)"` + UserName string `orm:"size(30);unique;description(user name)"` + Email string `orm:"size(100);description(email)"` + Password string `orm:"size(100);description(password)"` +} + +type ModelWithoutComments struct { + ID int `orm:"column(id)"` + UserName string `orm:"size(30);unique"` + Email string `orm:"size(100)"` + Password string `orm:"size(100)"` +} + +type ModelWithEmptyComments struct { + ID int `orm:"column(id);description()"` + UserName string `orm:"size(30);unique;description()"` + Email string `orm:"size(100);description()"` + Password string `orm:"size(100);description()"` +} +type ModelWithDBTypes struct { + ID int `orm:"column(id);description();db_type(bigserial NOT NULL PRIMARY KEY)"` + UserName string `orm:"size(30);unique;description()"` + Email string `orm:"size(100);description()"` + Password string `orm:"size(100);description()"` +} + +func TestGetDbCreateSQLWithComment(t *testing.T) { + type TestCase struct { + name string + model interface{} + wantSQL string + wantErr error + } + al := getDB("default") + var testCases []TestCase + switch al.driver { + case DRMySQL: + testCases = append(testCases, TestCase{name: "model with comments for MySQL", model: &ModelWithComments{}, wantSQL: "-- --------------------------------------------------\n-- Table Structure for `github.com/beego/beego/v2/client/orm.ModelWithComments`\n-- --------------------------------------------------\nCREATE TABLE IF NOT EXISTS `model_with_comments` (\n `id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY COMMENT 'user id',\n `user_name` varchar(30) NOT NULL DEFAULT '' UNIQUE COMMENT 'user name',\n `email` varchar(100) NOT NULL DEFAULT '' COMMENT 'email',\n `password` varchar(100) NOT NULL DEFAULT '' COMMENT 'password'\n) ENGINE=INNODB;", wantErr: nil}) + testCases = append(testCases, TestCase{name: "model without comments for MySQL", model: &ModelWithoutComments{}, wantSQL: "-- --------------------------------------------------\n-- Table Structure for `github.com/beego/beego/v2/client/orm.ModelWithoutComments`\n-- --------------------------------------------------\nCREATE TABLE IF NOT EXISTS `model_without_comments` (\n `id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY,\n `user_name` varchar(30) NOT NULL DEFAULT '' UNIQUE,\n `email` varchar(100) NOT NULL DEFAULT '' ,\n `password` varchar(100) NOT NULL DEFAULT '' \n) ENGINE=INNODB;", wantErr: nil}) + testCases = append(testCases, TestCase{name: "model with empty comments for MySQL", model: &ModelWithEmptyComments{}, wantSQL: "-- --------------------------------------------------\n-- Table Structure for `github.com/beego/beego/v2/client/orm.ModelWithEmptyComments`\n-- --------------------------------------------------\nCREATE TABLE IF NOT EXISTS `model_with_empty_comments` (\n `id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY,\n `user_name` varchar(30) NOT NULL DEFAULT '' UNIQUE,\n `email` varchar(100) NOT NULL DEFAULT '' ,\n `password` varchar(100) NOT NULL DEFAULT '' \n) ENGINE=INNODB;", wantErr: nil}) + testCases = append(testCases, TestCase{name: "model with dpType for MySQL", model: &ModelWithDBTypes{}, wantSQL: "-- --------------------------------------------------\n-- Table Structure for `github.com/beego/beego/v2/client/orm.ModelWithDBTypes`\n-- --------------------------------------------------\nCREATE TABLE IF NOT EXISTS `model_with_d_b_types` (\n `id` bigserial NOT NULL PRIMARY KEY,\n `user_name` varchar(30) NOT NULL DEFAULT '' UNIQUE,\n `email` varchar(100) NOT NULL DEFAULT '' ,\n `password` varchar(100) NOT NULL DEFAULT '' \n) ENGINE=INNODB;", wantErr: nil}) + case DRPostgres: + testCases = append(testCases, TestCase{name: "model with comments for Postgres", model: &ModelWithComments{}, wantSQL: "-- --------------------------------------------------\n-- Table Structure for `github.com/beego/beego/v2/client/orm.ModelWithComments`\n-- --------------------------------------------------\nCREATE TABLE IF NOT EXISTS \"model_with_comments\" (\n \"id\" bigserial NOT NULL PRIMARY KEY,\n \"user_name\" varchar(30) NOT NULL DEFAULT '' UNIQUE,\n \"email\" varchar(100) NOT NULL DEFAULT '' ,\n \"password\" varchar(100) NOT NULL DEFAULT '' \n);\nCOMMENT ON COLUMN \"model_with_comments\".\"id\" is 'user id';\nCOMMENT ON COLUMN \"model_with_comments\".\"user_name\" is 'user name';\nCOMMENT ON COLUMN \"model_with_comments\".\"email\" is 'email';\nCOMMENT ON COLUMN \"model_with_comments\".\"password\" is 'password';", wantErr: nil}) + testCases = append(testCases, TestCase{name: "model without comments for Postgres", model: &ModelWithoutComments{}, wantSQL: "-- --------------------------------------------------\n-- Table Structure for `github.com/beego/beego/v2/client/orm.ModelWithoutComments`\n-- --------------------------------------------------\nCREATE TABLE IF NOT EXISTS \"model_without_comments\" (\n \"id\" bigserial NOT NULL PRIMARY KEY,\n \"user_name\" varchar(30) NOT NULL DEFAULT '' UNIQUE,\n \"email\" varchar(100) NOT NULL DEFAULT '' ,\n \"password\" varchar(100) NOT NULL DEFAULT '' \n);", wantErr: nil}) + testCases = append(testCases, TestCase{name: "model with empty comments for Postgres", model: &ModelWithEmptyComments{}, wantSQL: "-- --------------------------------------------------\n-- Table Structure for `github.com/beego/beego/v2/client/orm.ModelWithEmptyComments`\n-- --------------------------------------------------\nCREATE TABLE IF NOT EXISTS \"model_with_empty_comments\" (\n \"id\" bigserial NOT NULL PRIMARY KEY,\n \"user_name\" varchar(30) NOT NULL DEFAULT '' UNIQUE,\n \"email\" varchar(100) NOT NULL DEFAULT '' ,\n \"password\" varchar(100) NOT NULL DEFAULT '' \n);", wantErr: nil}) + case DRSqlite: + testCases = append(testCases, TestCase{name: "model with comments for Sqlite", model: &ModelWithComments{}, wantSQL: "-- --------------------------------------------------\n-- Table Structure for `github.com/beego/beego/v2/client/orm.ModelWithComments`\n-- --------------------------------------------------\nCREATE TABLE IF NOT EXISTS `model_with_comments` (\n `id` integer NOT NULL PRIMARY KEY AUTOINCREMENT,\n `user_name` varchar(30) NOT NULL DEFAULT '' UNIQUE,\n `email` varchar(100) NOT NULL DEFAULT '' ,\n `password` varchar(100) NOT NULL DEFAULT '' \n);", wantErr: nil}) + testCases = append(testCases, TestCase{name: "model without comments for Sqlite", model: &ModelWithoutComments{}, wantSQL: "-- --------------------------------------------------\n-- Table Structure for `github.com/beego/beego/v2/client/orm.ModelWithoutComments`\n-- --------------------------------------------------\nCREATE TABLE IF NOT EXISTS `model_without_comments` (\n `id` integer NOT NULL PRIMARY KEY AUTOINCREMENT,\n `user_name` varchar(30) NOT NULL DEFAULT '' UNIQUE,\n `email` varchar(100) NOT NULL DEFAULT '' ,\n `password` varchar(100) NOT NULL DEFAULT '' \n);", wantErr: nil}) + testCases = append(testCases, TestCase{name: "model with empty comments for Sqlite", model: &ModelWithEmptyComments{}, wantSQL: "-- --------------------------------------------------\n-- Table Structure for `github.com/beego/beego/v2/client/orm.ModelWithEmptyComments`\n-- --------------------------------------------------\nCREATE TABLE IF NOT EXISTS `model_with_empty_comments` (\n `id` integer NOT NULL PRIMARY KEY AUTOINCREMENT,\n `user_name` varchar(30) NOT NULL DEFAULT '' UNIQUE,\n `email` varchar(100) NOT NULL DEFAULT '' ,\n `password` varchar(100) NOT NULL DEFAULT '' \n);", wantErr: nil}) + testCases = append(testCases, TestCase{name: "model with dpType for Sqlite", model: &ModelWithDBTypes{}, wantSQL: "-- --------------------------------------------------\n-- Table Structure for `github.com/beego/beego/v2/client/orm.ModelWithDBTypes`\n-- --------------------------------------------------\nCREATE TABLE IF NOT EXISTS `model_with_d_b_types` (\n `id` bigserial NOT NULL PRIMARY KEY,\n `user_name` varchar(30) NOT NULL DEFAULT '' UNIQUE,\n `email` varchar(100) NOT NULL DEFAULT '' ,\n `password` varchar(100) NOT NULL DEFAULT '' \n);", wantErr: nil}) + + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + registry := models.NewModelCacheHandler() + err := registry.Register("", true, tc.model) + assert.NoError(t, err) + queries, _, err := getDbCreateSQL(registry, al) + assert.Equal(t, tc.wantSQL, queries[0]) + assert.Equal(t, tc.wantErr, err) + }) + } +} diff --git a/client/orm/do_nothing_orm.go b/client/orm/do_nothing_orm.go new file mode 100644 index 0000000000..4baa8366f3 --- /dev/null +++ b/client/orm/do_nothing_orm.go @@ -0,0 +1,191 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "context" + "database/sql" + + "github.com/beego/beego/v2/core/utils" +) + +// DoNothingOrm won't do anything, usually you use this to custom your mock Ormer implementation +// I think golang mocking interface is hard to use +// this may help you to integrate with Ormer + +var _ Ormer = new(DoNothingOrm) + +type DoNothingOrm struct{} + +func (d *DoNothingOrm) Read(md interface{}, cols ...string) error { + return nil +} + +func (d *DoNothingOrm) ReadRaw(ctx context.Context, md interface{}, query string, args ...any) error { + return nil +} + +func (d *DoNothingOrm) ReadWithCtx(ctx context.Context, md interface{}, cols ...string) error { + return nil +} + +func (d *DoNothingOrm) ReadForUpdate(md interface{}, cols ...string) error { + return nil +} + +func (d *DoNothingOrm) ReadForUpdateWithCtx(ctx context.Context, md interface{}, cols ...string) error { + return nil +} + +func (d *DoNothingOrm) ReadOrCreate(md interface{}, col1 string, cols ...string) (bool, int64, error) { + return false, 0, nil +} + +func (d *DoNothingOrm) ReadOrCreateWithCtx(ctx context.Context, md interface{}, col1 string, cols ...string) (bool, int64, error) { + return false, 0, nil +} + +func (d *DoNothingOrm) LoadRelated(md interface{}, name string, args ...utils.KV) (int64, error) { + return 0, nil +} + +func (d *DoNothingOrm) LoadRelatedWithCtx(ctx context.Context, md interface{}, name string, args ...utils.KV) (int64, error) { + return 0, nil +} + +func (d *DoNothingOrm) QueryM2M(md interface{}, name string) QueryM2Mer { + return nil +} + +// NOTE: this method is deprecated, context parameter will not take effect. +func (d *DoNothingOrm) QueryM2MWithCtx(ctx context.Context, md interface{}, name string) QueryM2Mer { + return nil +} + +func (d *DoNothingOrm) QueryTable(ptrStructOrTableName interface{}) QuerySeter { + return nil +} + +// NOTE: this method is deprecated, context parameter will not take effect. +func (d *DoNothingOrm) QueryTableWithCtx(ctx context.Context, ptrStructOrTableName interface{}) QuerySeter { + return nil +} + +func (d *DoNothingOrm) DBStats() *sql.DBStats { + return nil +} + +func (d *DoNothingOrm) Insert(md interface{}) (int64, error) { + return 0, nil +} + +func (d *DoNothingOrm) InsertWithCtx(ctx context.Context, md interface{}) (int64, error) { + return 0, nil +} + +func (d *DoNothingOrm) InsertOrUpdate(md interface{}, colConflitAndArgs ...string) (int64, error) { + return 0, nil +} + +func (d *DoNothingOrm) InsertOrUpdateWithCtx(ctx context.Context, md interface{}, colConflitAndArgs ...string) (int64, error) { + return 0, nil +} + +func (d *DoNothingOrm) InsertMulti(bulk int, mds interface{}) (int64, error) { + return 0, nil +} + +func (d *DoNothingOrm) InsertMultiWithCtx(ctx context.Context, bulk int, mds interface{}) (int64, error) { + return 0, nil +} + +func (d *DoNothingOrm) Update(md interface{}, cols ...string) (int64, error) { + return 0, nil +} + +func (d *DoNothingOrm) UpdateWithCtx(ctx context.Context, md interface{}, cols ...string) (int64, error) { + return 0, nil +} + +func (d *DoNothingOrm) Delete(md interface{}, cols ...string) (int64, error) { + return 0, nil +} + +func (d *DoNothingOrm) DeleteWithCtx(ctx context.Context, md interface{}, cols ...string) (int64, error) { + return 0, nil +} + +func (d *DoNothingOrm) Raw(query string, args ...interface{}) RawSeter { + return nil +} + +func (d *DoNothingOrm) RawWithCtx(ctx context.Context, query string, args ...interface{}) RawSeter { + return nil +} + +func (*DoNothingOrm) ExecRaw(_ context.Context, _ interface{}, _ string, _ ...any) (sql.Result, error) { + return nil, nil +} + +func (d *DoNothingOrm) Driver() Driver { + return nil +} + +func (d *DoNothingOrm) Begin() (TxOrmer, error) { + return nil, nil +} + +func (d *DoNothingOrm) BeginWithCtx(ctx context.Context) (TxOrmer, error) { + return nil, nil +} + +func (d *DoNothingOrm) BeginWithOpts(opts *sql.TxOptions) (TxOrmer, error) { + return nil, nil +} + +func (d *DoNothingOrm) BeginWithCtxAndOpts(ctx context.Context, opts *sql.TxOptions) (TxOrmer, error) { + return nil, nil +} + +func (d *DoNothingOrm) DoTx(task func(ctx context.Context, txOrm TxOrmer) error) error { + return nil +} + +func (d *DoNothingOrm) DoTxWithCtx(ctx context.Context, task func(ctx context.Context, txOrm TxOrmer) error) error { + return nil +} + +func (d *DoNothingOrm) DoTxWithOpts(opts *sql.TxOptions, task func(ctx context.Context, txOrm TxOrmer) error) error { + return nil +} + +func (d *DoNothingOrm) DoTxWithCtxAndOpts(ctx context.Context, opts *sql.TxOptions, task func(ctx context.Context, txOrm TxOrmer) error) error { + return nil +} + +var _ Ormer = new(DoNothingTxOrm) + +// DoNothingTxOrm is similar with DoNothingOrm, usually you use it to test +type DoNothingTxOrm struct { + DoNothingOrm +} + +func (d *DoNothingTxOrm) Commit() error { + return nil +} + +func (d *DoNothingTxOrm) Rollback() error { + return nil +} diff --git a/client/orm/do_nothing_orm_test.go b/client/orm/do_nothing_orm_test.go new file mode 100644 index 0000000000..e10f70af05 --- /dev/null +++ b/client/orm/do_nothing_orm_test.go @@ -0,0 +1,132 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestDoNothingOrm(t *testing.T) { + o := &DoNothingOrm{} + err := o.DoTxWithCtxAndOpts(nil, nil, nil) + assert.Nil(t, err) + + err = o.DoTxWithCtx(nil, nil) + assert.Nil(t, err) + + err = o.DoTx(nil) + assert.Nil(t, err) + + err = o.DoTxWithOpts(nil, nil) + assert.Nil(t, err) + + assert.Nil(t, o.Driver()) + + assert.Nil(t, o.QueryM2M(nil, "")) + assert.Nil(t, o.ReadWithCtx(nil, nil)) + assert.Nil(t, o.Read(nil)) + + txOrm, err := o.BeginWithCtxAndOpts(nil, nil) + assert.Nil(t, err) + assert.Nil(t, txOrm) + + txOrm, err = o.BeginWithCtx(nil) + assert.Nil(t, err) + assert.Nil(t, txOrm) + + txOrm, err = o.BeginWithOpts(nil) + assert.Nil(t, err) + assert.Nil(t, txOrm) + + txOrm, err = o.Begin() + assert.Nil(t, err) + assert.Nil(t, txOrm) + + assert.Nil(t, o.RawWithCtx(nil, "")) + assert.Nil(t, o.Raw("")) + + i, err := o.InsertMulti(0, nil) + assert.Nil(t, err) + assert.Equal(t, int64(0), i) + + i, err = o.Insert(nil) + assert.Nil(t, err) + assert.Equal(t, int64(0), i) + + i, err = o.InsertWithCtx(nil, nil) + assert.Nil(t, err) + assert.Equal(t, int64(0), i) + + i, err = o.InsertOrUpdateWithCtx(nil, nil) + assert.Nil(t, err) + assert.Equal(t, int64(0), i) + + i, err = o.InsertOrUpdate(nil) + assert.Nil(t, err) + assert.Equal(t, int64(0), i) + + i, err = o.InsertMultiWithCtx(nil, 0, nil) + assert.Nil(t, err) + assert.Equal(t, int64(0), i) + + i, err = o.LoadRelatedWithCtx(nil, nil, "") + assert.Nil(t, err) + assert.Equal(t, int64(0), i) + + i, err = o.LoadRelated(nil, "") + assert.Nil(t, err) + assert.Equal(t, int64(0), i) + + assert.Nil(t, o.QueryTable(nil)) + + assert.Nil(t, o.Read(nil)) + assert.Nil(t, o.ReadWithCtx(nil, nil)) + assert.Nil(t, o.ReadForUpdateWithCtx(nil, nil)) + assert.Nil(t, o.ReadForUpdate(nil)) + + ok, i, err := o.ReadOrCreate(nil, "") + assert.Nil(t, err) + assert.Equal(t, int64(0), i) + assert.False(t, ok) + + ok, i, err = o.ReadOrCreateWithCtx(nil, nil, "") + assert.Nil(t, err) + assert.Equal(t, int64(0), i) + assert.False(t, ok) + + i, err = o.Delete(nil) + assert.Nil(t, err) + assert.Equal(t, int64(0), i) + + i, err = o.DeleteWithCtx(nil, nil) + assert.Nil(t, err) + assert.Equal(t, int64(0), i) + + i, err = o.Update(nil) + assert.Nil(t, err) + assert.Equal(t, int64(0), i) + + i, err = o.UpdateWithCtx(nil, nil) + assert.Nil(t, err) + assert.Equal(t, int64(0), i) + + assert.Nil(t, o.DBStats()) + + to := &DoNothingTxOrm{} + assert.Nil(t, to.Commit()) + assert.Nil(t, to.Rollback()) +} diff --git a/client/orm/filter.go b/client/orm/filter.go new file mode 100644 index 0000000000..cd313761c3 --- /dev/null +++ b/client/orm/filter.go @@ -0,0 +1,40 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "context" +) + +// FilterChain is used to build a Filter +// don't forget to call next(...) inside your Filter +type FilterChain func(next Filter) Filter + +// Filter behavior is a little big strange. +// it's only be called when users call methods of Ormer +// return value is an array. it's a little bit hard to understand, +// for example, the Ormer's Read method only return error +// so the filter processing this method should return an array whose first element is error +// and, Ormer's ReadOrCreateWithCtx return three values, so the Filter's result should contain three values +type Filter func(ctx context.Context, inv *Invocation) []interface{} + +var globalFilterChains = make([]FilterChain, 0, 4) + +// AddGlobalFilterChain adds a new FilterChain +// All orm instances built after this invocation will use this filterChain, +// but instances built before this invocation will not be affected +func AddGlobalFilterChain(filterChain ...FilterChain) { + globalFilterChains = append(globalFilterChains, filterChain...) +} diff --git a/client/orm/filter/bean/default_value_filter.go b/client/orm/filter/bean/default_value_filter.go new file mode 100644 index 0000000000..5d1fc2a168 --- /dev/null +++ b/client/orm/filter/bean/default_value_filter.go @@ -0,0 +1,133 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package bean + +import ( + "context" + "reflect" + "strings" + + "github.com/beego/beego/v2/client/orm" + "github.com/beego/beego/v2/core/bean" + "github.com/beego/beego/v2/core/logs" +) + +// DefaultValueFilterChainBuilder only works for InsertXXX method, +// But InsertOrUpdate and InsertOrUpdateWithCtx is more dangerous than other methods. +// so we won't handle those two methods unless you set includeInsertOrUpdate to true +// And if the element is not pointer, this filter doesn't work +type DefaultValueFilterChainBuilder struct { + factory bean.AutoWireBeanFactory + compatibleWithOldStyle bool + + // only the includeInsertOrUpdate is true, this filter will handle those two methods + includeInsertOrUpdate bool +} + +// NewDefaultValueFilterChainBuilder will create an instance of DefaultValueFilterChainBuilder +// In beego v1.x, the default value config looks like orm:default(xxxx) +// But the default value in 2.x is default:xxx +// so if you want to be compatible with v1.x, please pass true as compatibleWithOldStyle +func NewDefaultValueFilterChainBuilder(typeAdapters map[string]bean.TypeAdapter, + includeInsertOrUpdate bool, + compatibleWithOldStyle bool) *DefaultValueFilterChainBuilder { + factory := bean.NewTagAutoWireBeanFactory() + + if compatibleWithOldStyle { + newParser := factory.FieldTagParser + factory.FieldTagParser = func(field reflect.StructField) *bean.FieldMetadata { + if newParser != nil && field.Tag.Get(bean.DefaultValueTagKey) != "" { + return newParser(field) + } else { + res := &bean.FieldMetadata{} + ormMeta := field.Tag.Get("orm") + ormMetaParts := strings.Split(ormMeta, ";") + for _, p := range ormMetaParts { + if strings.HasPrefix(p, "default(") && strings.HasSuffix(p, ")") { + res.DftValue = p[8 : len(p)-1] + } + } + return res + } + } + } + + for k, v := range typeAdapters { + factory.Adapters[k] = v + } + + return &DefaultValueFilterChainBuilder{ + factory: factory, + compatibleWithOldStyle: compatibleWithOldStyle, + includeInsertOrUpdate: includeInsertOrUpdate, + } +} + +func (d *DefaultValueFilterChainBuilder) FilterChain(next orm.Filter) orm.Filter { + return func(ctx context.Context, inv *orm.Invocation) []interface{} { + switch inv.Method { + case "Insert", "InsertWithCtx": + d.handleInsert(ctx, inv) + case "InsertOrUpdate", "InsertOrUpdateWithCtx": + d.handleInsertOrUpdate(ctx, inv) + case "InsertMulti", "InsertMultiWithCtx": + d.handleInsertMulti(ctx, inv) + } + return next(ctx, inv) + } +} + +func (d *DefaultValueFilterChainBuilder) handleInsert(ctx context.Context, inv *orm.Invocation) { + d.setDefaultValue(ctx, inv.Args[0]) +} + +func (d *DefaultValueFilterChainBuilder) handleInsertOrUpdate(ctx context.Context, inv *orm.Invocation) { + if d.includeInsertOrUpdate { + ins := inv.Args[0] + if ins == nil { + return + } + + pkName := inv.GetPkFieldName() + pkField := reflect.Indirect(reflect.ValueOf(ins)).FieldByName(pkName) + + if pkField.IsZero() { + d.setDefaultValue(ctx, ins) + } + } +} + +func (d *DefaultValueFilterChainBuilder) handleInsertMulti(ctx context.Context, inv *orm.Invocation) { + mds := inv.Args[1] + + if t := reflect.TypeOf(mds).Kind(); t != reflect.Array && t != reflect.Slice { + // do nothing + return + } + + mdsArr := reflect.Indirect(reflect.ValueOf(mds)) + for i := 0; i < mdsArr.Len(); i++ { + d.setDefaultValue(ctx, mdsArr.Index(i).Interface()) + } + logs.Warn("%v", mdsArr.Index(0).Interface()) +} + +func (d *DefaultValueFilterChainBuilder) setDefaultValue(ctx context.Context, ins interface{}) { + err := d.factory.AutoWire(ctx, nil, ins) + if err != nil { + logs.Error("try to wire the bean for orm.Insert failed. "+ + "the default value is not set: %v, ", err) + } +} diff --git a/client/orm/filter/bean/default_value_filter_test.go b/client/orm/filter/bean/default_value_filter_test.go new file mode 100644 index 0000000000..ec47688d29 --- /dev/null +++ b/client/orm/filter/bean/default_value_filter_test.go @@ -0,0 +1,71 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package bean + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/beego/beego/v2/client/orm" +) + +func TestDefaultValueFilterChainBuilderFilterChain(t *testing.T) { + builder := NewDefaultValueFilterChainBuilder(nil, true, true) + o := orm.NewFilterOrmDecorator(&defaultValueTestOrm{}, builder.FilterChain) + + // test insert + entity := &DefaultValueTestEntity{} + _, _ = o.Insert(entity) + assert.Equal(t, 12, entity.Age) + assert.Equal(t, 13, entity.AgeInOldStyle) + assert.Equal(t, 0, entity.AgeIgnore) + + // test InsertOrUpdate + entity = &DefaultValueTestEntity{} + orm.RegisterModel(entity) + + _, _ = o.InsertOrUpdate(entity) + assert.Equal(t, 12, entity.Age) + assert.Equal(t, 13, entity.AgeInOldStyle) + + // we won't set the default value because we find the pk is not Zero value + entity.Id = 3 + entity.AgeInOldStyle = 0 + _, _ = o.InsertOrUpdate(entity) + assert.Equal(t, 0, entity.AgeInOldStyle) + + entity = &DefaultValueTestEntity{} + + // the entity is not array, it will be ignored + _, _ = o.InsertMulti(3, entity) + assert.Equal(t, 0, entity.Age) + assert.Equal(t, 0, entity.AgeInOldStyle) + + _, _ = o.InsertMulti(3, []*DefaultValueTestEntity{entity}) + assert.Equal(t, 12, entity.Age) + assert.Equal(t, 13, entity.AgeInOldStyle) +} + +type defaultValueTestOrm struct { + orm.DoNothingOrm +} + +type DefaultValueTestEntity struct { + Id int + Age int `default:"12"` + AgeInOldStyle int `orm:"default(13);bee()"` + AgeIgnore int +} diff --git a/client/orm/filter/opentelemetry/filter.go b/client/orm/filter/opentelemetry/filter.go new file mode 100644 index 0000000000..d7740ac40e --- /dev/null +++ b/client/orm/filter/opentelemetry/filter.go @@ -0,0 +1,90 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package opentelemetry + +import ( + "context" + "strings" + + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/attribute" + otelTrace "go.opentelemetry.io/otel/trace" + + "github.com/beego/beego/v2/client/orm" +) + +type ( + CustomSpanFunc func(ctx context.Context, span otelTrace.Span, inv *orm.Invocation) + FilterChainOption func(fcv *FilterChainBuilder) +) + +// FilterChainBuilder provides an opentelemtry Filter +type FilterChainBuilder struct { + // customSpanFunc users are able to custom their span + customSpanFunc CustomSpanFunc +} + +func NewFilterChainBuilder(options ...FilterChainOption) *FilterChainBuilder { + fcb := &FilterChainBuilder{} + for _, o := range options { + o(fcb) + } + return fcb +} + +// WithCustomSpanFunc add function to custom span +func WithCustomSpanFunc(customSpanFunc CustomSpanFunc) FilterChainOption { + return func(fcv *FilterChainBuilder) { + fcv.customSpanFunc = customSpanFunc + } +} + +// FilterChain traces invocation with opentelemetry +// Unless invocation.Method is Begin*, Commit or Rollback +func (builder *FilterChainBuilder) FilterChain(next orm.Filter) orm.Filter { + return func(ctx context.Context, inv *orm.Invocation) []interface{} { + if strings.HasPrefix(inv.Method, "Begin") || inv.Method == "Commit" || inv.Method == "Rollback" { + return next(ctx, inv) + } + spanCtx, span := otel.Tracer("beego_orm").Start(ctx, invOperationName(ctx, inv)) + defer span.End() + + res := next(spanCtx, inv) + builder.buildSpan(spanCtx, span, inv) + return res + } +} + +// buildSpan add default span attributes and custom attributes with customSpanFunc +func (builder *FilterChainBuilder) buildSpan(ctx context.Context, span otelTrace.Span, inv *orm.Invocation) { + span.SetAttributes(attribute.String("orm.method", inv.Method)) + span.SetAttributes(attribute.String("orm.table", inv.GetTableName())) + span.SetAttributes(attribute.Bool("orm.insideTx", inv.InsideTx)) + v, _ := ctx.Value(orm.TxNameKey).(string) + span.SetAttributes(attribute.String("orm.txName", v)) + span.SetAttributes(attribute.String("span.kind", "client")) + span.SetAttributes(attribute.String("component", "beego")) + + if builder.customSpanFunc != nil { + builder.customSpanFunc(ctx, span, inv) + } +} + +func invOperationName(ctx context.Context, inv *orm.Invocation) string { + if n, ok := ctx.Value(orm.TxNameKey).(string); ok { + return inv.Method + "#tx(" + n + ")" + } + return inv.Method + "#" + inv.GetTableName() +} diff --git a/client/orm/filter/opentelemetry/filter_test.go b/client/orm/filter/opentelemetry/filter_test.go new file mode 100644 index 0000000000..589cb320c6 --- /dev/null +++ b/client/orm/filter/opentelemetry/filter_test.go @@ -0,0 +1,61 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package opentelemetry + +import ( + "bytes" + "context" + "testing" + + "github.com/stretchr/testify/assert" + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/exporters/stdout/stdouttrace" + "go.opentelemetry.io/otel/sdk/trace" + otelTrace "go.opentelemetry.io/otel/trace" + + "github.com/beego/beego/v2/client/orm" +) + +func TestFilterChainBuilderFilterChain(t *testing.T) { + // Init Trace + buf := bytes.NewBuffer([]byte{}) + exp, err := stdouttrace.New(stdouttrace.WithWriter(buf)) + if err != nil { + t.Error(err) + } + tp := trace.NewTracerProvider(trace.WithBatcher(exp)) + otel.SetTracerProvider(tp) + + // Build FilterChain + csf := func(ctx context.Context, span otelTrace.Span, inv *orm.Invocation) { + span.SetAttributes(attribute.String("hello", "work")) + } + builder := NewFilterChainBuilder(WithCustomSpanFunc(csf)) + + inv := &orm.Invocation{Method: "Hello"} + next := func(ctx context.Context, inv *orm.Invocation) []interface{} { return nil } + + builder.FilterChain(next)(context.Background(), inv) + + // Close tp + err = tp.Shutdown(context.Background()) + if err != nil { + t.Error(err) + } + + // Assert opentelemetry span name + assert.Equal(t, "Hello#", string(buf.Bytes()[9:15])) +} diff --git a/client/orm/filter/opentracing/filter.go b/client/orm/filter/opentracing/filter.go new file mode 100644 index 0000000000..7afa07f189 --- /dev/null +++ b/client/orm/filter/opentracing/filter.go @@ -0,0 +1,71 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package opentracing + +import ( + "context" + "strings" + + "github.com/opentracing/opentracing-go" + + "github.com/beego/beego/v2/client/orm" +) + +// FilterChainBuilder provides an extension point +// this Filter's behavior looks a little bit strange +// for example: +// if we want to trace QuerySetter +// actually we trace invoking "QueryTable" +// the method Begin*, Commit and Rollback are ignored. +// When use using those methods, it means that they want to manager their transaction manually, so we won't handle them. +type FilterChainBuilder struct { + // CustomSpanFunc users are able to custom their span + CustomSpanFunc func(span opentracing.Span, ctx context.Context, inv *orm.Invocation) +} + +func (builder *FilterChainBuilder) FilterChain(next orm.Filter) orm.Filter { + return func(ctx context.Context, inv *orm.Invocation) []interface{} { + operationName := builder.operationName(ctx, inv) + if strings.HasPrefix(inv.Method, "Begin") || inv.Method == "Commit" || inv.Method == "Rollback" { + return next(ctx, inv) + } + + span, spanCtx := opentracing.StartSpanFromContext(ctx, operationName) + defer span.Finish() + res := next(spanCtx, inv) + builder.buildSpan(span, spanCtx, inv) + return res + } +} + +func (builder *FilterChainBuilder) buildSpan(span opentracing.Span, ctx context.Context, inv *orm.Invocation) { + span.SetTag("orm.method", inv.Method) + span.SetTag("orm.table", inv.GetTableName()) + span.SetTag("orm.insideTx", inv.InsideTx) + span.SetTag("orm.txName", ctx.Value(orm.TxNameKey)) + span.SetTag("span.kind", "client") + span.SetTag("component", "beego") + + if builder.CustomSpanFunc != nil { + builder.CustomSpanFunc(span, ctx, inv) + } +} + +func (builder *FilterChainBuilder) operationName(ctx context.Context, inv *orm.Invocation) string { + if n, ok := ctx.Value(orm.TxNameKey).(string); ok { + return inv.Method + "#tx(" + n + ")" + } + return inv.Method + "#" + inv.GetTableName() +} diff --git a/client/orm/filter/opentracing/filter_test.go b/client/orm/filter/opentracing/filter_test.go new file mode 100644 index 0000000000..67de5c9e72 --- /dev/null +++ b/client/orm/filter/opentracing/filter_test.go @@ -0,0 +1,44 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package opentracing + +import ( + "context" + "testing" + "time" + + "github.com/opentracing/opentracing-go" + + "github.com/beego/beego/v2/client/orm" +) + +func TestFilterChainBuilderFilterChain(t *testing.T) { + next := func(ctx context.Context, inv *orm.Invocation) []interface{} { + inv.TxName = "Hello" + return []interface{}{} + } + + builder := &FilterChainBuilder{ + CustomSpanFunc: func(span opentracing.Span, ctx context.Context, inv *orm.Invocation) { + span.SetTag("hello", "hell") + }, + } + + inv := &orm.Invocation{ + Method: "Hello", + TxStartTime: time.Now(), + } + builder.FilterChain(next)(context.Background(), inv) +} diff --git a/client/orm/filter/prometheus/filter.go b/client/orm/filter/prometheus/filter.go new file mode 100644 index 0000000000..f525cebcbf --- /dev/null +++ b/client/orm/filter/prometheus/filter.go @@ -0,0 +1,92 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package prometheus + +import ( + "context" + "strconv" + "strings" + "sync" + "time" + + "github.com/prometheus/client_golang/prometheus" + + "github.com/beego/beego/v2/client/orm" +) + +// FilterChainBuilder is an extension point, +// when we want to support some configuration, +// please use this structure +// this Filter's behavior looks a little bit strange +// for example: +// if we want to records the metrics of QuerySetter +// actually we only records metrics of invoking "QueryTable" +type FilterChainBuilder struct { + AppName string + ServerName string + RunMode string +} + +var ( + summaryVec prometheus.ObserverVec + initSummaryVec sync.Once +) + +func (builder *FilterChainBuilder) FilterChain(next orm.Filter) orm.Filter { + initSummaryVec.Do(func() { + summaryVec = prometheus.NewSummaryVec(prometheus.SummaryOpts{ + Name: "beego", + Subsystem: "orm_operation", + ConstLabels: map[string]string{ + "server": builder.ServerName, + "env": builder.RunMode, + "appname": builder.AppName, + }, + Help: "The statics info for orm operation", + }, []string{"method", "name", "insideTx", "txName"}) + prometheus.MustRegister(summaryVec) + }) + + return func(ctx context.Context, inv *orm.Invocation) []interface{} { + startTime := time.Now() + res := next(ctx, inv) + endTime := time.Now() + dur := (endTime.Sub(startTime)) / time.Millisecond + + // if the TPS is too large, here may be some problem + // thinking about using goroutine pool + go builder.report(ctx, inv, dur) + return res + } +} + +func (builder *FilterChainBuilder) report(ctx context.Context, inv *orm.Invocation, dur time.Duration) { + // start a transaction, we don't record it + if strings.HasPrefix(inv.Method, "Begin") { + return + } + if inv.Method == "Commit" || inv.Method == "Rollback" { + builder.reportTxn(ctx, inv) + return + } + summaryVec.WithLabelValues(inv.Method, inv.GetTableName(), + strconv.FormatBool(inv.InsideTx), inv.TxName).Observe(float64(dur)) +} + +func (builder *FilterChainBuilder) reportTxn(ctx context.Context, inv *orm.Invocation) { + dur := time.Since(inv.TxStartTime) / time.Millisecond + summaryVec.WithLabelValues(inv.Method, inv.TxName, + strconv.FormatBool(inv.InsideTx), inv.TxName).Observe(float64(dur)) +} diff --git a/client/orm/filter/prometheus/filter_test.go b/client/orm/filter/prometheus/filter_test.go new file mode 100644 index 0000000000..912ae073b7 --- /dev/null +++ b/client/orm/filter/prometheus/filter_test.go @@ -0,0 +1,61 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package prometheus + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "github.com/beego/beego/v2/client/orm" +) + +func TestFilterChainBuilderFilterChain1(t *testing.T) { + next := func(ctx context.Context, inv *orm.Invocation) []interface{} { + inv.Method = "coming" + return []interface{}{} + } + builder := &FilterChainBuilder{} + filter := builder.FilterChain(next) + + assert.NotNil(t, summaryVec) + assert.NotNil(t, filter) + + inv := &orm.Invocation{} + filter(context.Background(), inv) + assert.Equal(t, "coming", inv.Method) + + inv = &orm.Invocation{ + Method: "Hello", + TxStartTime: time.Now(), + } + builder.reportTxn(context.Background(), inv) + + inv = &orm.Invocation{ + Method: "Begin", + } + + ctx := context.Background() + // it will be ignored + builder.report(ctx, inv, time.Second) + + inv.Method = "Commit" + builder.report(ctx, inv, time.Second) + + inv.Method = "Update" + builder.report(ctx, inv, time.Second) +} diff --git a/client/orm/filter_orm_decorator.go b/client/orm/filter_orm_decorator.go new file mode 100644 index 0000000000..1a4af25c58 --- /dev/null +++ b/client/orm/filter_orm_decorator.go @@ -0,0 +1,575 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "context" + "database/sql" + "reflect" + "time" + + utils2 "github.com/beego/beego/v2/client/orm/internal/utils" + + "github.com/beego/beego/v2/client/orm/internal/models" + + "github.com/beego/beego/v2/core/logs" + "github.com/beego/beego/v2/core/utils" +) + +const ( + TxNameKey = "TxName" +) + +var ( + _ Ormer = new(filterOrmDecorator) + _ TxOrmer = new(filterOrmDecorator) +) + +type filterOrmDecorator struct { + ormer + TxBeginner + TxCommitter + + root Filter + + insideTx bool + txStartTime time.Time + txName string +} + +func NewFilterOrmDecorator(delegate Ormer, filterChains ...FilterChain) Ormer { + res := &filterOrmDecorator{ + ormer: delegate, + TxBeginner: delegate, + root: func(ctx context.Context, inv *Invocation) []interface{} { + return inv.execute(ctx) + }, + } + + for i := len(filterChains) - 1; i >= 0; i-- { + node := filterChains[i] + res.root = node(res.root) + } + return res +} + +func NewFilterTxOrmDecorator(delegate TxOrmer, root Filter, txName string) TxOrmer { + res := &filterOrmDecorator{ + ormer: delegate, + TxCommitter: delegate, + root: root, + insideTx: true, + txStartTime: time.Now(), + txName: txName, + } + return res +} + +func (f *filterOrmDecorator) Read(md interface{}, cols ...string) error { + return f.ReadWithCtx(context.Background(), md, cols...) +} + +// TODO: implementation code +func (f *filterOrmDecorator) ReadRaw(ctx context.Context, md interface{}, query string, args ...any) error { + mi, _ := models.DefaultModelCache.GetByMd(md) + inv := &Invocation{ + Method: "ReadRaw", + Args: []interface{}{md, query, args}, + Md: md, + mi: mi, + InsideTx: f.insideTx, + TxStartTime: f.txStartTime, + f: func(c context.Context) []interface{} { + err := f.ormer.ReadRaw(c, md, query, args...) + return []interface{}{err} + }, + } + res := f.root(ctx, inv) + return f.convertError(res[0]) +} + +func (f *filterOrmDecorator) ExecRaw(ctx context.Context, md interface{}, query string, args ...any) (sql.Result, error) { + mi, _ := models.DefaultModelCache.GetByMd(md) + inv := &Invocation{ + Method: "ExecRaw", + Args: []interface{}{md, query, args}, + Md: md, + mi: mi, + InsideTx: f.insideTx, + TxStartTime: f.txStartTime, + f: func(c context.Context) []interface{} { + res, err := f.ormer.ExecRaw(c, md, query, args...) + return []interface{}{res, err} + }, + } + res := f.root(ctx, inv) + return res[0].(sql.Result), f.convertError(res[1]) +} + +func (f *filterOrmDecorator) ReadWithCtx(ctx context.Context, md interface{}, cols ...string) error { + mi, _ := models.DefaultModelCache.GetByMd(md) + inv := &Invocation{ + Method: "ReadWithCtx", + Args: []interface{}{md, cols}, + Md: md, + mi: mi, + InsideTx: f.insideTx, + TxStartTime: f.txStartTime, + f: func(c context.Context) []interface{} { + err := f.ormer.ReadWithCtx(c, md, cols...) + return []interface{}{err} + }, + } + res := f.root(ctx, inv) + return f.convertError(res[0]) +} + +func (f *filterOrmDecorator) ReadForUpdate(md interface{}, cols ...string) error { + return f.ReadForUpdateWithCtx(context.Background(), md, cols...) +} + +func (f *filterOrmDecorator) ReadForUpdateWithCtx(ctx context.Context, md interface{}, cols ...string) error { + mi, _ := models.DefaultModelCache.GetByMd(md) + inv := &Invocation{ + Method: "ReadForUpdateWithCtx", + Args: []interface{}{md, cols}, + Md: md, + mi: mi, + InsideTx: f.insideTx, + TxStartTime: f.txStartTime, + f: func(c context.Context) []interface{} { + err := f.ormer.ReadForUpdateWithCtx(c, md, cols...) + return []interface{}{err} + }, + } + res := f.root(ctx, inv) + return f.convertError(res[0]) +} + +func (f *filterOrmDecorator) ReadOrCreate(md interface{}, col1 string, cols ...string) (bool, int64, error) { + return f.ReadOrCreateWithCtx(context.Background(), md, col1, cols...) +} + +func (f *filterOrmDecorator) ReadOrCreateWithCtx(ctx context.Context, md interface{}, col1 string, cols ...string) (bool, int64, error) { + mi, _ := models.DefaultModelCache.GetByMd(md) + inv := &Invocation{ + Method: "ReadOrCreateWithCtx", + Args: []interface{}{md, col1, cols}, + Md: md, + mi: mi, + InsideTx: f.insideTx, + TxStartTime: f.txStartTime, + f: func(c context.Context) []interface{} { + ok, res, err := f.ormer.ReadOrCreateWithCtx(c, md, col1, cols...) + return []interface{}{ok, res, err} + }, + } + res := f.root(ctx, inv) + return res[0].(bool), res[1].(int64), f.convertError(res[2]) +} + +func (f *filterOrmDecorator) LoadRelated(md interface{}, name string, args ...utils.KV) (int64, error) { + return f.LoadRelatedWithCtx(context.Background(), md, name, args...) +} + +func (f *filterOrmDecorator) LoadRelatedWithCtx(ctx context.Context, md interface{}, name string, args ...utils.KV) (int64, error) { + mi, _ := models.DefaultModelCache.GetByMd(md) + inv := &Invocation{ + Method: "LoadRelatedWithCtx", + Args: []interface{}{md, name, args}, + Md: md, + mi: mi, + InsideTx: f.insideTx, + TxStartTime: f.txStartTime, + f: func(c context.Context) []interface{} { + res, err := f.ormer.LoadRelatedWithCtx(c, md, name, args...) + return []interface{}{res, err} + }, + } + res := f.root(ctx, inv) + return res[0].(int64), f.convertError(res[1]) +} + +func (f *filterOrmDecorator) QueryM2M(md interface{}, name string) QueryM2Mer { + mi, _ := models.DefaultModelCache.GetByMd(md) + inv := &Invocation{ + Method: "QueryM2M", + Args: []interface{}{md, name}, + Md: md, + mi: mi, + InsideTx: f.insideTx, + TxStartTime: f.txStartTime, + f: func(c context.Context) []interface{} { + res := f.ormer.QueryM2M(md, name) + return []interface{}{res} + }, + } + res := f.root(context.Background(), inv) + if res[0] == nil { + return nil + } + return res[0].(QueryM2Mer) +} + +// NOTE: this method is deprecated, context parameter will not take effect. +func (f *filterOrmDecorator) QueryM2MWithCtx(_ context.Context, md interface{}, name string) QueryM2Mer { + logs.Warn("QueryM2MWithCtx is DEPRECATED. Use methods with `WithCtx` on QueryM2Mer suffix as replacement.") + return f.QueryM2M(md, name) +} + +func (f *filterOrmDecorator) QueryTable(ptrStructOrTableName interface{}) QuerySeter { + var ( + name string + md interface{} + mi *models.ModelInfo + ) + + if table, ok := ptrStructOrTableName.(string); ok { + name = table + } else { + name = models.GetFullName(utils2.IndirectType(reflect.TypeOf(ptrStructOrTableName))) + md = ptrStructOrTableName + } + + if m, ok := models.DefaultModelCache.GetByFullName(name); ok { + mi = m + } + + inv := &Invocation{ + Method: "QueryTable", + Args: []interface{}{ptrStructOrTableName}, + InsideTx: f.insideTx, + TxStartTime: f.txStartTime, + Md: md, + mi: mi, + f: func(c context.Context) []interface{} { + res := f.ormer.QueryTable(ptrStructOrTableName) + return []interface{}{res} + }, + } + res := f.root(context.Background(), inv) + + if res[0] == nil { + return nil + } + return res[0].(QuerySeter) +} + +// NOTE: this method is deprecated, context parameter will not take effect. +func (f *filterOrmDecorator) QueryTableWithCtx(_ context.Context, ptrStructOrTableName interface{}) QuerySeter { + logs.Warn("QueryTableWithCtx is DEPRECATED. Use methods with `WithCtx`on QuerySeter suffix as replacement.") + return f.QueryTable(ptrStructOrTableName) +} + +func (f *filterOrmDecorator) DBStats() *sql.DBStats { + inv := &Invocation{ + Method: "DBStats", + InsideTx: f.insideTx, + TxStartTime: f.txStartTime, + f: func(c context.Context) []interface{} { + res := f.ormer.DBStats() + return []interface{}{res} + }, + } + res := f.root(context.Background(), inv) + + if res[0] == nil { + return nil + } + + return res[0].(*sql.DBStats) +} + +func (f *filterOrmDecorator) Insert(md interface{}) (int64, error) { + return f.InsertWithCtx(context.Background(), md) +} + +func (f *filterOrmDecorator) InsertWithCtx(ctx context.Context, md interface{}) (int64, error) { + mi, _ := models.DefaultModelCache.GetByMd(md) + inv := &Invocation{ + Method: "InsertWithCtx", + Args: []interface{}{md}, + Md: md, + mi: mi, + InsideTx: f.insideTx, + TxStartTime: f.txStartTime, + f: func(c context.Context) []interface{} { + res, err := f.ormer.InsertWithCtx(c, md) + return []interface{}{res, err} + }, + } + res := f.root(ctx, inv) + return res[0].(int64), f.convertError(res[1]) +} + +func (f *filterOrmDecorator) InsertOrUpdate(md interface{}, colConflitAndArgs ...string) (int64, error) { + return f.InsertOrUpdateWithCtx(context.Background(), md, colConflitAndArgs...) +} + +func (f *filterOrmDecorator) InsertOrUpdateWithCtx(ctx context.Context, md interface{}, colConflitAndArgs ...string) (int64, error) { + mi, _ := models.DefaultModelCache.GetByMd(md) + inv := &Invocation{ + Method: "InsertOrUpdateWithCtx", + Args: []interface{}{md, colConflitAndArgs}, + Md: md, + mi: mi, + InsideTx: f.insideTx, + TxStartTime: f.txStartTime, + f: func(c context.Context) []interface{} { + res, err := f.ormer.InsertOrUpdateWithCtx(c, md, colConflitAndArgs...) + return []interface{}{res, err} + }, + } + res := f.root(ctx, inv) + return res[0].(int64), f.convertError(res[1]) +} + +func (f *filterOrmDecorator) InsertMulti(bulk int, mds interface{}) (int64, error) { + return f.InsertMultiWithCtx(context.Background(), bulk, mds) +} + +// InsertMultiWithCtx uses the first element's model info +func (f *filterOrmDecorator) InsertMultiWithCtx(ctx context.Context, bulk int, mds interface{}) (int64, error) { + var ( + md interface{} + mi *models.ModelInfo + ) + + sind := reflect.Indirect(reflect.ValueOf(mds)) + + if (sind.Kind() == reflect.Array || sind.Kind() == reflect.Slice) && sind.Len() > 0 { + ind := reflect.Indirect(sind.Index(0)) + md = ind.Interface() + mi, _ = models.DefaultModelCache.GetByMd(md) + } + + inv := &Invocation{ + Method: "InsertMultiWithCtx", + Args: []interface{}{bulk, mds}, + Md: md, + mi: mi, + InsideTx: f.insideTx, + TxStartTime: f.txStartTime, + f: func(c context.Context) []interface{} { + res, err := f.ormer.InsertMultiWithCtx(c, bulk, mds) + return []interface{}{res, err} + }, + } + res := f.root(ctx, inv) + return res[0].(int64), f.convertError(res[1]) +} + +func (f *filterOrmDecorator) Update(md interface{}, cols ...string) (int64, error) { + return f.UpdateWithCtx(context.Background(), md, cols...) +} + +func (f *filterOrmDecorator) UpdateWithCtx(ctx context.Context, md interface{}, cols ...string) (int64, error) { + mi, _ := models.DefaultModelCache.GetByMd(md) + inv := &Invocation{ + Method: "UpdateWithCtx", + Args: []interface{}{md, cols}, + Md: md, + mi: mi, + InsideTx: f.insideTx, + TxStartTime: f.txStartTime, + f: func(c context.Context) []interface{} { + res, err := f.ormer.UpdateWithCtx(c, md, cols...) + return []interface{}{res, err} + }, + } + res := f.root(ctx, inv) + return res[0].(int64), f.convertError(res[1]) +} + +func (f *filterOrmDecorator) Delete(md interface{}, cols ...string) (int64, error) { + return f.DeleteWithCtx(context.Background(), md, cols...) +} + +func (f *filterOrmDecorator) DeleteWithCtx(ctx context.Context, md interface{}, cols ...string) (int64, error) { + mi, _ := models.DefaultModelCache.GetByMd(md) + inv := &Invocation{ + Method: "DeleteWithCtx", + Args: []interface{}{md, cols}, + Md: md, + mi: mi, + InsideTx: f.insideTx, + TxStartTime: f.txStartTime, + f: func(c context.Context) []interface{} { + res, err := f.ormer.DeleteWithCtx(c, md, cols...) + return []interface{}{res, err} + }, + } + res := f.root(ctx, inv) + return res[0].(int64), f.convertError(res[1]) +} + +func (f *filterOrmDecorator) Raw(query string, args ...interface{}) RawSeter { + return f.RawWithCtx(context.Background(), query, args...) +} + +func (f *filterOrmDecorator) RawWithCtx(ctx context.Context, query string, args ...interface{}) RawSeter { + inv := &Invocation{ + Method: "RawWithCtx", + Args: []interface{}{query, args}, + InsideTx: f.insideTx, + TxStartTime: f.txStartTime, + f: func(c context.Context) []interface{} { + res := f.ormer.RawWithCtx(c, query, args...) + return []interface{}{res} + }, + } + res := f.root(ctx, inv) + + if res[0] == nil { + return nil + } + return res[0].(RawSeter) +} + +func (f *filterOrmDecorator) Driver() Driver { + inv := &Invocation{ + Method: "Driver", + InsideTx: f.insideTx, + TxStartTime: f.txStartTime, + f: func(c context.Context) []interface{} { + res := f.ormer.Driver() + return []interface{}{res} + }, + } + res := f.root(context.Background(), inv) + if res[0] == nil { + return nil + } + return res[0].(Driver) +} + +func (f *filterOrmDecorator) Begin() (TxOrmer, error) { + return f.BeginWithCtxAndOpts(context.Background(), nil) +} + +func (f *filterOrmDecorator) BeginWithCtx(ctx context.Context) (TxOrmer, error) { + return f.BeginWithCtxAndOpts(ctx, nil) +} + +func (f *filterOrmDecorator) BeginWithOpts(opts *sql.TxOptions) (TxOrmer, error) { + return f.BeginWithCtxAndOpts(context.Background(), opts) +} + +func (f *filterOrmDecorator) BeginWithCtxAndOpts(ctx context.Context, opts *sql.TxOptions) (TxOrmer, error) { + inv := &Invocation{ + Method: "BeginWithCtxAndOpts", + Args: []interface{}{opts}, + InsideTx: f.insideTx, + TxStartTime: f.txStartTime, + f: func(c context.Context) []interface{} { + res, err := f.TxBeginner.BeginWithCtxAndOpts(c, opts) + res = NewFilterTxOrmDecorator(res, f.root, getTxNameFromCtx(c)) + return []interface{}{res, err} + }, + } + res := f.root(ctx, inv) + return res[0].(TxOrmer), f.convertError(res[1]) +} + +func (f *filterOrmDecorator) DoTx(task func(ctx context.Context, txOrm TxOrmer) error) error { + return f.DoTxWithCtxAndOpts(context.Background(), nil, task) +} + +func (f *filterOrmDecorator) DoTxWithCtx(ctx context.Context, task func(ctx context.Context, txOrm TxOrmer) error) error { + return f.DoTxWithCtxAndOpts(ctx, nil, task) +} + +func (f *filterOrmDecorator) DoTxWithOpts(opts *sql.TxOptions, task func(ctx context.Context, txOrm TxOrmer) error) error { + return f.DoTxWithCtxAndOpts(context.Background(), opts, task) +} + +func (f *filterOrmDecorator) DoTxWithCtxAndOpts(ctx context.Context, opts *sql.TxOptions, task func(ctx context.Context, txOrm TxOrmer) error) error { + inv := &Invocation{ + Method: "DoTxWithCtxAndOpts", + Args: []interface{}{opts, task}, + InsideTx: f.insideTx, + TxStartTime: f.txStartTime, + TxName: getTxNameFromCtx(ctx), + f: func(c context.Context) []interface{} { + err := doTxTemplate(c, f, opts, task) + return []interface{}{err} + }, + } + res := f.root(ctx, inv) + return f.convertError(res[0]) +} + +func (f *filterOrmDecorator) Commit() error { + inv := &Invocation{ + Method: "Commit", + Args: []interface{}{}, + InsideTx: f.insideTx, + TxStartTime: f.txStartTime, + TxName: f.txName, + f: func(c context.Context) []interface{} { + err := f.TxCommitter.Commit() + return []interface{}{err} + }, + } + res := f.root(context.Background(), inv) + return f.convertError(res[0]) +} + +func (f *filterOrmDecorator) Rollback() error { + inv := &Invocation{ + Method: "Rollback", + Args: []interface{}{}, + InsideTx: f.insideTx, + TxStartTime: f.txStartTime, + TxName: f.txName, + f: func(c context.Context) []interface{} { + err := f.TxCommitter.Rollback() + return []interface{}{err} + }, + } + res := f.root(context.Background(), inv) + return f.convertError(res[0]) +} + +func (f *filterOrmDecorator) RollbackUnlessCommit() error { + inv := &Invocation{ + Method: "RollbackUnlessCommit", + Args: []interface{}{}, + InsideTx: f.insideTx, + TxStartTime: f.txStartTime, + TxName: f.txName, + f: func(c context.Context) []interface{} { + err := f.TxCommitter.RollbackUnlessCommit() + return []interface{}{err} + }, + } + res := f.root(context.Background(), inv) + return f.convertError(res[0]) +} + +func (*filterOrmDecorator) convertError(v interface{}) error { + if v == nil { + return nil + } + return v.(error) +} + +func getTxNameFromCtx(ctx context.Context) string { + txName := "" + if n, ok := ctx.Value(TxNameKey).(string); ok { + txName = n + } + return txName +} diff --git a/client/orm/filter_orm_decorator_test.go b/client/orm/filter_orm_decorator_test.go new file mode 100644 index 0000000000..e190804196 --- /dev/null +++ b/client/orm/filter_orm_decorator_test.go @@ -0,0 +1,437 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "context" + "database/sql" + "errors" + "sync" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/beego/beego/v2/core/utils" +) + +func TestFilterOrmDecoratorRead(t *testing.T) { + register() + + o := &filterMockOrm{} + od := NewFilterOrmDecorator(o, func(next Filter) Filter { + return func(ctx context.Context, inv *Invocation) []interface{} { + assert.Equal(t, "ReadWithCtx", inv.Method) + assert.Equal(t, 2, len(inv.Args)) + assert.Equal(t, "FILTER_TEST", inv.GetTableName()) + return next(ctx, inv) + } + }) + + fte := &FilterTestEntity{} + err := od.Read(fte) + assert.NotNil(t, err) + assert.Equal(t, "read error", err.Error()) +} + +func TestFilterOrmDecoratorBeginTx(t *testing.T) { + register() + + o := &filterMockOrm{} + od := NewFilterOrmDecorator(o, func(next Filter) Filter { + return func(ctx context.Context, inv *Invocation) []interface{} { + if inv.Method == "BeginWithCtxAndOpts" { + assert.Equal(t, 1, len(inv.Args)) + assert.Equal(t, "", inv.GetTableName()) + assert.False(t, inv.InsideTx) + } else if inv.Method == "Commit" { + assert.Equal(t, 0, len(inv.Args)) + assert.Equal(t, "Commit_tx", inv.TxName) + assert.Equal(t, "", inv.GetTableName()) + assert.True(t, inv.InsideTx) + } else if inv.Method == "Rollback" { + assert.Equal(t, 0, len(inv.Args)) + assert.Equal(t, "Rollback_tx", inv.TxName) + assert.Equal(t, "", inv.GetTableName()) + assert.True(t, inv.InsideTx) + } else { + t.Fail() + } + + return next(ctx, inv) + } + }) + to, err := od.Begin() + assert.True(t, validateBeginResult(t, to, err)) + + to, err = od.BeginWithOpts(nil) + assert.True(t, validateBeginResult(t, to, err)) + + ctx := context.WithValue(context.Background(), TxNameKey, "Commit_tx") + to, err = od.BeginWithCtx(ctx) + assert.True(t, validateBeginResult(t, to, err)) + + err = to.Commit() + assert.NotNil(t, err) + assert.Equal(t, "commit", err.Error()) + + ctx = context.WithValue(context.Background(), TxNameKey, "Rollback_tx") + to, err = od.BeginWithCtxAndOpts(ctx, nil) + assert.True(t, validateBeginResult(t, to, err)) + + err = to.Rollback() + assert.NotNil(t, err) + assert.Equal(t, "rollback", err.Error()) +} + +func TestFilterOrmDecoratorDBStats(t *testing.T) { + o := &filterMockOrm{} + od := NewFilterOrmDecorator(o, func(next Filter) Filter { + return func(ctx context.Context, inv *Invocation) []interface{} { + assert.Equal(t, "DBStats", inv.Method) + assert.Equal(t, 0, len(inv.Args)) + assert.Equal(t, "", inv.GetTableName()) + return next(ctx, inv) + } + }) + res := od.DBStats() + assert.NotNil(t, res) + assert.Equal(t, -1, res.MaxOpenConnections) +} + +func TestFilterOrmDecoratorDelete(t *testing.T) { + register() + o := &filterMockOrm{} + od := NewFilterOrmDecorator(o, func(next Filter) Filter { + return func(ctx context.Context, inv *Invocation) []interface{} { + assert.Equal(t, "DeleteWithCtx", inv.Method) + assert.Equal(t, 2, len(inv.Args)) + assert.Equal(t, "FILTER_TEST", inv.GetTableName()) + return next(ctx, inv) + } + }) + res, err := od.Delete(&FilterTestEntity{}) + assert.NotNil(t, err) + assert.Equal(t, "delete error", err.Error()) + assert.Equal(t, int64(-2), res) +} + +func TestFilterOrmDecoratorDoTx(t *testing.T) { + o := &filterMockOrm{} + od := NewFilterOrmDecorator(o, func(next Filter) Filter { + return func(ctx context.Context, inv *Invocation) []interface{} { + if inv.Method == "DoTxWithCtxAndOpts" { + assert.Equal(t, 2, len(inv.Args)) + assert.Equal(t, "", inv.GetTableName()) + assert.False(t, inv.InsideTx) + } + return next(ctx, inv) + } + }) + + err := od.DoTx(func(c context.Context, txOrm TxOrmer) error { + return nil + }) + assert.NotNil(t, err) + + err = od.DoTxWithCtx(context.Background(), func(c context.Context, txOrm TxOrmer) error { + return nil + }) + assert.NotNil(t, err) + + err = od.DoTxWithOpts(nil, func(c context.Context, txOrm TxOrmer) error { + return nil + }) + assert.NotNil(t, err) + + od = NewFilterOrmDecorator(o, func(next Filter) Filter { + return func(ctx context.Context, inv *Invocation) []interface{} { + if inv.Method == "DoTxWithCtxAndOpts" { + assert.Equal(t, 2, len(inv.Args)) + assert.Equal(t, "", inv.GetTableName()) + assert.Equal(t, "do tx name", inv.TxName) + assert.False(t, inv.InsideTx) + } + return next(ctx, inv) + } + }) + + ctx := context.WithValue(context.Background(), TxNameKey, "do tx name") + err = od.DoTxWithCtxAndOpts(ctx, nil, func(c context.Context, txOrm TxOrmer) error { + return nil + }) + assert.NotNil(t, err) +} + +func TestFilterOrmDecoratorDriver(t *testing.T) { + o := &filterMockOrm{} + od := NewFilterOrmDecorator(o, func(next Filter) Filter { + return func(ctx context.Context, inv *Invocation) []interface{} { + assert.Equal(t, "Driver", inv.Method) + assert.Equal(t, 0, len(inv.Args)) + assert.Equal(t, "", inv.GetTableName()) + assert.False(t, inv.InsideTx) + return next(ctx, inv) + } + }) + res := od.Driver() + assert.Nil(t, res) +} + +func TestFilterOrmDecoratorInsert(t *testing.T) { + register() + o := &filterMockOrm{} + od := NewFilterOrmDecorator(o, func(next Filter) Filter { + return func(ctx context.Context, inv *Invocation) []interface{} { + assert.Equal(t, "InsertWithCtx", inv.Method) + assert.Equal(t, 1, len(inv.Args)) + assert.Equal(t, "FILTER_TEST", inv.GetTableName()) + assert.False(t, inv.InsideTx) + return next(ctx, inv) + } + }) + + i, err := od.Insert(&FilterTestEntity{}) + assert.NotNil(t, err) + assert.Equal(t, "insert error", err.Error()) + assert.Equal(t, int64(100), i) +} + +func TestFilterOrmDecoratorInsertMulti(t *testing.T) { + register() + o := &filterMockOrm{} + od := NewFilterOrmDecorator(o, func(next Filter) Filter { + return func(ctx context.Context, inv *Invocation) []interface{} { + assert.Equal(t, "InsertMultiWithCtx", inv.Method) + assert.Equal(t, 2, len(inv.Args)) + assert.Equal(t, "FILTER_TEST", inv.GetTableName()) + assert.False(t, inv.InsideTx) + return next(ctx, inv) + } + }) + + bulk := []*FilterTestEntity{{}, {}} + i, err := od.InsertMulti(2, bulk) + assert.NotNil(t, err) + assert.Equal(t, "insert multi error", err.Error()) + assert.Equal(t, int64(2), i) +} + +func TestFilterOrmDecoratorInsertOrUpdate(t *testing.T) { + register() + o := &filterMockOrm{} + od := NewFilterOrmDecorator(o, func(next Filter) Filter { + return func(ctx context.Context, inv *Invocation) []interface{} { + assert.Equal(t, "InsertOrUpdateWithCtx", inv.Method) + assert.Equal(t, 2, len(inv.Args)) + assert.Equal(t, "FILTER_TEST", inv.GetTableName()) + assert.False(t, inv.InsideTx) + return next(ctx, inv) + } + }) + i, err := od.InsertOrUpdate(&FilterTestEntity{}) + assert.NotNil(t, err) + assert.Equal(t, "insert or update error", err.Error()) + assert.Equal(t, int64(1), i) +} + +func TestFilterOrmDecoratorLoadRelated(t *testing.T) { + o := &filterMockOrm{} + od := NewFilterOrmDecorator(o, func(next Filter) Filter { + return func(ctx context.Context, inv *Invocation) []interface{} { + assert.Equal(t, "LoadRelatedWithCtx", inv.Method) + assert.Equal(t, 3, len(inv.Args)) + assert.Equal(t, "FILTER_TEST", inv.GetTableName()) + assert.False(t, inv.InsideTx) + return next(ctx, inv) + } + }) + i, err := od.LoadRelated(&FilterTestEntity{}, "hello") + assert.NotNil(t, err) + assert.Equal(t, "load related error", err.Error()) + assert.Equal(t, int64(99), i) +} + +func TestFilterOrmDecoratorQueryM2M(t *testing.T) { + o := &filterMockOrm{} + od := NewFilterOrmDecorator(o, func(next Filter) Filter { + return func(ctx context.Context, inv *Invocation) []interface{} { + assert.Equal(t, "QueryM2M", inv.Method) + assert.Equal(t, 2, len(inv.Args)) + assert.Equal(t, "FILTER_TEST", inv.GetTableName()) + assert.False(t, inv.InsideTx) + return next(ctx, inv) + } + }) + res := od.QueryM2M(&FilterTestEntity{}, "hello") + assert.Nil(t, res) +} + +func TestFilterOrmDecoratorQueryTable(t *testing.T) { + register() + o := &filterMockOrm{} + od := NewFilterOrmDecorator(o, func(next Filter) Filter { + return func(ctx context.Context, inv *Invocation) []interface{} { + assert.Equal(t, "QueryTable", inv.Method) + assert.Equal(t, 1, len(inv.Args)) + assert.Equal(t, "FILTER_TEST", inv.GetTableName()) + assert.False(t, inv.InsideTx) + return next(ctx, inv) + } + }) + res := od.QueryTable(&FilterTestEntity{}) + assert.Nil(t, res) +} + +func TestFilterOrmDecoratorRaw(t *testing.T) { + register() + o := &filterMockOrm{} + od := NewFilterOrmDecorator(o, func(next Filter) Filter { + return func(ctx context.Context, inv *Invocation) []interface{} { + assert.Equal(t, "RawWithCtx", inv.Method) + assert.Equal(t, 2, len(inv.Args)) + assert.Equal(t, "", inv.GetTableName()) + assert.False(t, inv.InsideTx) + return next(ctx, inv) + } + }) + res := od.Raw("hh") + assert.Nil(t, res) +} + +func TestFilterOrmDecoratorReadForUpdate(t *testing.T) { + register() + o := &filterMockOrm{} + od := NewFilterOrmDecorator(o, func(next Filter) Filter { + return func(ctx context.Context, inv *Invocation) []interface{} { + assert.Equal(t, "ReadForUpdateWithCtx", inv.Method) + assert.Equal(t, 2, len(inv.Args)) + assert.Equal(t, "FILTER_TEST", inv.GetTableName()) + assert.False(t, inv.InsideTx) + return next(ctx, inv) + } + }) + err := od.ReadForUpdate(&FilterTestEntity{}) + assert.NotNil(t, err) + assert.Equal(t, "read for update error", err.Error()) +} + +func TestFilterOrmDecoratorReadOrCreate(t *testing.T) { + register() + o := &filterMockOrm{} + od := NewFilterOrmDecorator(o, func(next Filter) Filter { + return func(ctx context.Context, inv *Invocation) []interface{} { + assert.Equal(t, "ReadOrCreateWithCtx", inv.Method) + assert.Equal(t, 3, len(inv.Args)) + assert.Equal(t, "FILTER_TEST", inv.GetTableName()) + assert.False(t, inv.InsideTx) + return next(ctx, inv) + } + }) + ok, i, err := od.ReadOrCreate(&FilterTestEntity{}, "name") + assert.NotNil(t, err) + assert.Equal(t, "read or create error", err.Error()) + assert.True(t, ok) + assert.Equal(t, int64(13), i) +} + +var _ Ormer = new(filterMockOrm) + +// filterMockOrm is only used in this test file +type filterMockOrm struct { + DoNothingOrm +} + +func (f *filterMockOrm) ReadOrCreateWithCtx(ctx context.Context, md interface{}, col1 string, cols ...string) (bool, int64, error) { + return true, 13, errors.New("read or create error") +} + +func (f *filterMockOrm) ReadForUpdateWithCtx(ctx context.Context, md interface{}, cols ...string) error { + return errors.New("read for update error") +} + +func (f *filterMockOrm) LoadRelatedWithCtx(ctx context.Context, md interface{}, name string, args ...utils.KV) (int64, error) { + return 99, errors.New("load related error") +} + +func (f *filterMockOrm) InsertOrUpdateWithCtx(ctx context.Context, md interface{}, colConflitAndArgs ...string) (int64, error) { + return 1, errors.New("insert or update error") +} + +func (f *filterMockOrm) InsertMultiWithCtx(ctx context.Context, bulk int, mds interface{}) (int64, error) { + return 2, errors.New("insert multi error") +} + +func (f *filterMockOrm) InsertWithCtx(ctx context.Context, md interface{}) (int64, error) { + return 100, errors.New("insert error") +} + +func (f *filterMockOrm) DoTxWithCtxAndOpts(ctx context.Context, opts *sql.TxOptions, task func(c context.Context, txOrm TxOrmer) error) error { + return task(ctx, nil) +} + +func (f *filterMockOrm) DeleteWithCtx(ctx context.Context, md interface{}, cols ...string) (int64, error) { + return -2, errors.New("delete error") +} + +func (f *filterMockOrm) BeginWithCtxAndOpts(ctx context.Context, opts *sql.TxOptions) (TxOrmer, error) { + return &filterMockOrm{}, errors.New("begin tx") +} + +func (f *filterMockOrm) ReadWithCtx(ctx context.Context, md interface{}, cols ...string) error { + return errors.New("read error") +} + +func (f *filterMockOrm) Commit() error { + return errors.New("commit") +} + +func (f *filterMockOrm) Rollback() error { + return errors.New("rollback") +} + +func (f *filterMockOrm) RollbackUnlessCommit() error { + return errors.New("rollback unless commit") +} + +func (f *filterMockOrm) DBStats() *sql.DBStats { + return &sql.DBStats{ + MaxOpenConnections: -1, + } +} + +func validateBeginResult(t *testing.T, to TxOrmer, err error) bool { + assert.NotNil(t, err) + assert.Equal(t, "begin tx", err.Error()) + _, ok := to.(*filterOrmDecorator).TxCommitter.(*filterMockOrm) + assert.True(t, ok) + return true +} + +var filterTestEntityRegisterOnce sync.Once + +type FilterTestEntity struct { + ID int + Name string +} + +func register() { + filterTestEntityRegisterOnce.Do(func() { + RegisterModel(&FilterTestEntity{}) + }) +} + +func (f *FilterTestEntity) TableName() string { + return "FILTER_TEST" +} diff --git a/client/orm/filter_test.go b/client/orm/filter_test.go new file mode 100644 index 0000000000..f9c86039b4 --- /dev/null +++ b/client/orm/filter_test.go @@ -0,0 +1,32 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestAddGlobalFilterChain(t *testing.T) { + AddGlobalFilterChain(func(next Filter) Filter { + return func(ctx context.Context, inv *Invocation) []interface{} { + return next(ctx, inv) + } + }) + assert.Equal(t, 1, len(globalFilterChains)) + globalFilterChains = nil +} diff --git a/client/orm/hints/db_hints.go b/client/orm/hints/db_hints.go new file mode 100644 index 0000000000..6578a5955e --- /dev/null +++ b/client/orm/hints/db_hints.go @@ -0,0 +1,103 @@ +// Copyright 2020 beego-dev +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package hints + +import ( + "github.com/beego/beego/v2/core/utils" +) + +const ( + // query level + KeyForceIndex = iota + KeyUseIndex + KeyIgnoreIndex + KeyForUpdate + KeyLimit + KeyOffset + KeyOrderBy + KeyRelDepth +) + +type Hint struct { + key interface{} + value interface{} +} + +var _ utils.KV = new(Hint) + +// GetKey return key +func (s *Hint) GetKey() interface{} { + return s.key +} + +// GetValue return value +func (s *Hint) GetValue() interface{} { + return s.value +} + +var _ utils.KV = new(Hint) + +// ForceIndex return a hint about ForceIndex +func ForceIndex(indexes ...string) *Hint { + return NewHint(KeyForceIndex, indexes) +} + +// UseIndex return a hint about UseIndex +func UseIndex(indexes ...string) *Hint { + return NewHint(KeyUseIndex, indexes) +} + +// IgnoreIndex return a hint about IgnoreIndex +func IgnoreIndex(indexes ...string) *Hint { + return NewHint(KeyIgnoreIndex, indexes) +} + +// ForUpdate return a hint about ForUpdate +func ForUpdate() *Hint { + return NewHint(KeyForUpdate, true) +} + +// DefaultRelDepth return a hint about DefaultRelDepth +func DefaultRelDepth() *Hint { + return NewHint(KeyRelDepth, true) +} + +// RelDepth return a hint about RelDepth +func RelDepth(d int) *Hint { + return NewHint(KeyRelDepth, d) +} + +// Limit return a hint about Limit +func Limit(d int64) *Hint { + return NewHint(KeyLimit, d) +} + +// Offset return a hint about Offset +func Offset(d int64) *Hint { + return NewHint(KeyOffset, d) +} + +// OrderBy return a hint about OrderBy +func OrderBy(s string) *Hint { + return NewHint(KeyOrderBy, s) +} + +// NewHint return a hint +func NewHint(key interface{}, value interface{}) *Hint { + return &Hint{ + key: key, + value: value, + } +} diff --git a/client/orm/hints/db_hints_test.go b/client/orm/hints/db_hints_test.go new file mode 100644 index 0000000000..ab18ee81f1 --- /dev/null +++ b/client/orm/hints/db_hints_test.go @@ -0,0 +1,127 @@ +// Copyright 2020 beego-dev +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package hints + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestNewHintTime(t *testing.T) { + key := "qweqwe" + value := time.Second + hint := NewHint(key, value) + + assert.Equal(t, hint.GetKey(), key) + assert.Equal(t, hint.GetValue(), value) +} + +func TestNewHintInt(t *testing.T) { + key := "qweqwe" + value := 281230 + hint := NewHint(key, value) + + assert.Equal(t, hint.GetKey(), key) + assert.Equal(t, hint.GetValue(), value) +} + +func TestNewHintFloat(t *testing.T) { + key := "qweqwe" + value := 21.2459753 + hint := NewHint(key, value) + + assert.Equal(t, hint.GetKey(), key) + assert.Equal(t, hint.GetValue(), value) +} + +func TestForceIndex(t *testing.T) { + s := []string{`f_index1`, `f_index2`, `f_index3`} + hint := ForceIndex(s...) + assert.Equal(t, hint.GetValue(), s) + assert.Equal(t, hint.GetKey(), KeyForceIndex) +} + +func TestForceIndex0(t *testing.T) { + var s []string + hint := ForceIndex(s...) + assert.Equal(t, hint.GetValue(), s) + assert.Equal(t, hint.GetKey(), KeyForceIndex) +} + +func TestIgnoreIndex(t *testing.T) { + s := []string{`i_index1`, `i_index2`, `i_index3`} + hint := IgnoreIndex(s...) + assert.Equal(t, hint.GetValue(), s) + assert.Equal(t, hint.GetKey(), KeyIgnoreIndex) +} + +func TestIgnoreIndex0(t *testing.T) { + var s []string + hint := IgnoreIndex(s...) + assert.Equal(t, hint.GetValue(), s) + assert.Equal(t, hint.GetKey(), KeyIgnoreIndex) +} + +func TestUseIndex(t *testing.T) { + s := []string{`u_index1`, `u_index2`, `u_index3`} + hint := UseIndex(s...) + assert.Equal(t, hint.GetValue(), s) + assert.Equal(t, hint.GetKey(), KeyUseIndex) +} + +func TestUseIndex0(t *testing.T) { + var s []string + hint := UseIndex(s...) + assert.Equal(t, hint.GetValue(), s) + assert.Equal(t, hint.GetKey(), KeyUseIndex) +} + +func TestForUpdate(t *testing.T) { + hint := ForUpdate() + assert.Equal(t, hint.GetValue(), true) + assert.Equal(t, hint.GetKey(), KeyForUpdate) +} + +func TestDefaultRelDepth(t *testing.T) { + hint := DefaultRelDepth() + assert.Equal(t, hint.GetValue(), true) + assert.Equal(t, hint.GetKey(), KeyRelDepth) +} + +func TestRelDepth(t *testing.T) { + hint := RelDepth(157965) + assert.Equal(t, hint.GetValue(), 157965) + assert.Equal(t, hint.GetKey(), KeyRelDepth) +} + +func TestLimit(t *testing.T) { + hint := Limit(1579625) + assert.Equal(t, hint.GetValue(), int64(1579625)) + assert.Equal(t, hint.GetKey(), KeyLimit) +} + +func TestOffset(t *testing.T) { + hint := Offset(int64(1572123965)) + assert.Equal(t, hint.GetValue(), int64(1572123965)) + assert.Equal(t, hint.GetKey(), KeyOffset) +} + +func TestOrderBy(t *testing.T) { + hint := OrderBy(`-ID`) + assert.Equal(t, hint.GetValue(), `-ID`) + assert.Equal(t, hint.GetKey(), KeyOrderBy) +} diff --git a/client/orm/internal/buffers/buffers.go b/client/orm/internal/buffers/buffers.go new file mode 100644 index 0000000000..db341ff4e1 --- /dev/null +++ b/client/orm/internal/buffers/buffers.go @@ -0,0 +1,34 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package buffers + +import "github.com/valyala/bytebufferpool" + +var _ Buffer = &bytebufferpool.ByteBuffer{} + +type Buffer interface { + Write(p []byte) (int, error) + WriteString(s string) (int, error) + WriteByte(c byte) error + String() string +} + +func Get() Buffer { + return bytebufferpool.Get() +} + +func Put(bf Buffer) { + bytebufferpool.Put(bf.(*bytebufferpool.ByteBuffer)) +} diff --git a/client/orm/internal/logs/log.go b/client/orm/internal/logs/log.go new file mode 100644 index 0000000000..bbee9b70ee --- /dev/null +++ b/client/orm/internal/logs/log.go @@ -0,0 +1,17 @@ +package logs + +import ( + "io" + "log" +) + +// Log implement the log.Logger +type Log struct { + *log.Logger +} + +func NewLog(out io.Writer) *Log { + d := new(Log) + d.Logger = log.New(out, "[ORM]", log.LstdFlags) + return d +} diff --git a/client/orm/internal/models/models.go b/client/orm/internal/models/models.go new file mode 100644 index 0000000000..db6bc27706 --- /dev/null +++ b/client/orm/internal/models/models.go @@ -0,0 +1,397 @@ +// Copyright 2023 beego. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package models + +import ( + "fmt" + "reflect" + "runtime/debug" + "strings" + "sync" + + "github.com/beego/beego/v2/client/orm/qb/errs" +) + +var DefaultModelCache = NewModelCacheHandler() + +// ModelCache info collection +type ModelCache struct { + orders []string + cache map[string]*ModelInfo + cacheByFullName map[string]*ModelInfo + bootstrapOnce *sync.Once +} + +// NewModelCacheHandler generator of ModelCache +func NewModelCacheHandler() *ModelCache { + return &ModelCache{ + cache: make(map[string]*ModelInfo), + cacheByFullName: make(map[string]*ModelInfo), + bootstrapOnce: new(sync.Once), + } +} + +// All return all model info +func (mc *ModelCache) All() map[string]*ModelInfo { + m := make(map[string]*ModelInfo, len(mc.cache)) + for k, v := range mc.cache { + m[k] = v + } + return m +} + +func (mc *ModelCache) Empty() bool { + return len(mc.cache) == 0 +} + +func (mc *ModelCache) AllOrdered() []*ModelInfo { + m := make([]*ModelInfo, 0, len(mc.orders)) + for _, table := range mc.orders { + m = append(m, mc.cache[table]) + } + return m +} + +// Get model info by table name +func (mc *ModelCache) Get(table string) (mi *ModelInfo, ok bool) { + mi, ok = mc.cache[table] + return +} + +// GetByFullName model info by full name +func (mc *ModelCache) GetByFullName(name string) (mi *ModelInfo, ok bool) { + mi, ok = mc.cacheByFullName[name] + return +} + +func (mc *ModelCache) GetByMd(md interface{}) (*ModelInfo, bool) { + val := reflect.ValueOf(md) + ind := reflect.Indirect(val) + typ := ind.Type() + name := GetFullName(typ) + return mc.GetByFullName(name) +} + +func (mc *ModelCache) GetOrRegisterByMd(md interface{}) (*ModelInfo, error) { + model, ok := mc.GetByMd(md) + if !ok { + err := mc.Register("", true, md) + if err != nil { + return nil, err + } + model, ok = mc.GetByMd(md) + if !ok { + return nil, errs.ErrGetByMd + } + } + return model, nil +} + +// Set model info to collection +func (mc *ModelCache) Set(table string, mi *ModelInfo) *ModelInfo { + mii := mc.cache[table] + mc.cache[table] = mi + mc.cacheByFullName[mi.FullName] = mi + if mii == nil { + mc.orders = append(mc.orders, table) + } + return mii +} + +// Clean All model info. +func (mc *ModelCache) Clean() { + mc.orders = make([]string, 0) + mc.cache = make(map[string]*ModelInfo) + mc.cacheByFullName = make(map[string]*ModelInfo) + mc.bootstrapOnce = new(sync.Once) +} + +// Bootstrap Bootstrap for models +func (mc *ModelCache) Bootstrap() { + mc.bootstrapOnce.Do(func() { + mc.bootstrap() + }) +} + +func (mc *ModelCache) bootstrap() { + var ( + err error + models map[string]*ModelInfo + ) + // Set rel and reverse model + // RelManyToMany Set the relTable + models = mc.All() + for _, mi := range models { + for _, fi := range mi.Fields.Columns { + if fi.Rel || fi.Reverse { + elm := fi.AddrValue.Type().Elem() + if fi.FieldType == RelReverseMany || fi.FieldType == RelManyToMany { + elm = elm.Elem() + } + // check the rel or reverse model already Register + name := GetFullName(elm) + mii, ok := mc.GetByFullName(name) + if !ok || mii.Pkg != elm.PkgPath() { + err = fmt.Errorf("can not find rel in field `%s`, `%s` may be miss Register", fi.FullName, elm.String()) + goto end + } + fi.RelModelInfo = mii + + switch fi.FieldType { + case RelManyToMany: + if fi.RelThrough != "" { + if i := strings.LastIndex(fi.RelThrough, "."); i != -1 && len(fi.RelThrough) > (i+1) { + pn := fi.RelThrough[:i] + rmi, ok := mc.GetByFullName(fi.RelThrough) + if !ok || pn != rmi.Pkg { + err = fmt.Errorf("field `%s` wrong rel_through value `%s` cannot find table", fi.FullName, fi.RelThrough) + goto end + } + fi.RelThroughModelInfo = rmi + fi.RelTable = rmi.Table + } else { + err = fmt.Errorf("field `%s` wrong rel_through value `%s`", fi.FullName, fi.RelThrough) + goto end + } + } else { + i := NewM2MModelInfo(mi, mii) + if fi.RelTable != "" { + i.Table = fi.RelTable + } + if v := mc.Set(i.Table, i); v != nil { + err = fmt.Errorf("the rel table name `%s` already registered, cannot be use, please change one", fi.RelTable) + goto end + } + fi.RelTable = i.Table + fi.RelThroughModelInfo = i + } + + fi.RelThroughModelInfo.IsThrough = true + } + } + } + } + + // check the rel filed while the relModelInfo also has filed point to current model + // if not exist, add a new field to the relModelInfo + models = mc.All() + for _, mi := range models { + for _, fi := range mi.Fields.FieldsRel { + switch fi.FieldType { + case RelForeignKey, RelOneToOne, RelManyToMany: + inModel := false + for _, ffi := range fi.RelModelInfo.Fields.FieldsReverse { + if ffi.RelModelInfo == mi { + inModel = true + break + } + } + if !inModel { + rmi := fi.RelModelInfo + ffi := new(FieldInfo) + ffi.Name = mi.Name + ffi.Column = ffi.Name + ffi.FullName = rmi.FullName + "." + ffi.Name + ffi.Reverse = true + ffi.RelModelInfo = mi + ffi.Mi = rmi + if fi.FieldType == RelOneToOne { + ffi.FieldType = RelReverseOne + } else { + ffi.FieldType = RelReverseMany + } + if !rmi.Fields.Add(ffi) { + added := false + for cnt := 0; cnt < 5; cnt++ { + ffi.Name = fmt.Sprintf("%s%d", mi.Name, cnt) + ffi.Column = ffi.Name + ffi.FullName = rmi.FullName + "." + ffi.Name + if added = rmi.Fields.Add(ffi); added { + break + } + } + if !added { + panic(fmt.Errorf("cannot generate auto reverse field info `%s` to `%s`", fi.FullName, ffi.FullName)) + } + } + } + } + } + } + + models = mc.All() + for _, mi := range models { + for _, fi := range mi.Fields.FieldsRel { + switch fi.FieldType { + case RelManyToMany: + for _, ffi := range fi.RelThroughModelInfo.Fields.FieldsRel { + switch ffi.FieldType { + case RelOneToOne, RelForeignKey: + if ffi.RelModelInfo == fi.RelModelInfo { + fi.ReverseFieldInfoTwo = ffi + } + if ffi.RelModelInfo == mi { + fi.ReverseField = ffi.Name + fi.ReverseFieldInfo = ffi + } + } + } + if fi.ReverseFieldInfoTwo == nil { + err = fmt.Errorf("can not find m2m field for m2m model `%s`, ensure your m2m model defined correct", + fi.RelThroughModelInfo.FullName) + goto end + } + } + } + } + + models = mc.All() + for _, mi := range models { + for _, fi := range mi.Fields.FieldsReverse { + switch fi.FieldType { + case RelReverseOne: + found := false + mForA: + for _, ffi := range fi.RelModelInfo.Fields.FieldsByType[RelOneToOne] { + if ffi.RelModelInfo == mi { + found = true + fi.ReverseField = ffi.Name + fi.ReverseFieldInfo = ffi + + ffi.ReverseField = fi.Name + ffi.ReverseFieldInfo = fi + break mForA + } + } + if !found { + err = fmt.Errorf("reverse field `%s` not found in model `%s`", fi.FullName, fi.RelModelInfo.FullName) + goto end + } + case RelReverseMany: + found := false + mForB: + for _, ffi := range fi.RelModelInfo.Fields.FieldsByType[RelForeignKey] { + if ffi.RelModelInfo == mi { + found = true + fi.ReverseField = ffi.Name + fi.ReverseFieldInfo = ffi + + ffi.ReverseField = fi.Name + ffi.ReverseFieldInfo = fi + + break mForB + } + } + if !found { + mForC: + for _, ffi := range fi.RelModelInfo.Fields.FieldsByType[RelManyToMany] { + conditions := fi.RelThrough != "" && fi.RelThrough == ffi.RelThrough || + fi.RelTable != "" && fi.RelTable == ffi.RelTable || + fi.RelThrough == "" && fi.RelTable == "" + if ffi.RelModelInfo == mi && conditions { + found = true + + fi.ReverseField = ffi.ReverseFieldInfoTwo.Name + fi.ReverseFieldInfo = ffi.ReverseFieldInfoTwo + fi.RelThroughModelInfo = ffi.RelThroughModelInfo + fi.ReverseFieldInfoTwo = ffi.ReverseFieldInfo + fi.ReverseFieldInfoM2M = ffi + ffi.ReverseFieldInfoM2M = fi + + break mForC + } + } + } + if !found { + err = fmt.Errorf("reverse field for `%s` not found in model `%s`", fi.FullName, fi.RelModelInfo.FullName) + goto end + } + } + } + } + +end: + if err != nil { + fmt.Println(err) + debug.PrintStack() + } +} + +// Register Register models to model cache +func (mc *ModelCache) Register(prefixOrSuffixStr string, prefixOrSuffix bool, models ...interface{}) (err error) { + for _, model := range models { + val := reflect.ValueOf(model) + typ := reflect.Indirect(val).Type() + + if val.Kind() != reflect.Ptr { + err = fmt.Errorf(" cannot use non-ptr model struct `%s`", GetFullName(typ)) + return + } + // For this case: + // u := &User{} + // registerModel(&u) + if typ.Kind() == reflect.Ptr { + err = fmt.Errorf(" only allow ptr model struct, it looks you use two reference to the struct `%s`", typ) + return + } + if val.Elem().Kind() == reflect.Slice { + val = reflect.New(val.Elem().Type().Elem()) + } + table := GetTableName(val) + + if prefixOrSuffixStr != "" { + if prefixOrSuffix { + table = prefixOrSuffixStr + table + } else { + table = table + prefixOrSuffixStr + } + } + + // models's fullname is pkgpath + struct name + name := GetFullName(typ) + if _, ok := mc.GetByFullName(name); ok { + err = fmt.Errorf(" model `%s` repeat Register, must be unique\n", name) + return + } + + if _, ok := mc.Get(table); ok { + return nil + } + + mi := NewModelInfo(val) + if mi.Fields.Pk == nil { + outFor: + for _, fi := range mi.Fields.FieldsDB { + if strings.ToLower(fi.Name) == "id" { + switch fi.AddrValue.Elem().Kind() { + case reflect.Int, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint32, reflect.Uint64: + fi.Auto = true + fi.Pk = true + mi.Fields.Pk = fi + break outFor + } + } + } + } + + mi.Table = table + mi.Pkg = typ.PkgPath() + mi.Model = model + mi.Manual = true + + mc.Set(table, mi) + } + return +} diff --git a/orm/models_fields.go b/client/orm/internal/models/models_fields.go similarity index 95% rename from orm/models_fields.go rename to client/orm/internal/models/models_fields.go index b4fad94f41..70e5aafa9b 100644 --- a/orm/models_fields.go +++ b/client/orm/internal/models/models_fields.go @@ -12,12 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. -package orm +package models import ( "fmt" "strconv" "time" + + "github.com/beego/beego/v2/client/orm/internal/utils" ) // Define the Type enum @@ -85,7 +87,7 @@ func (e *BooleanField) SetRaw(value interface{}) error { case bool: e.Set(d) case string: - v, err := StrTo(d).Bool() + v, err := utils.StrTo(d).Bool() if err == nil { e.Set(v) } @@ -190,7 +192,7 @@ func (e *TimeField) SetRaw(value interface{}) error { case time.Time: e.Set(d) case string: - v, err := timeParse(d, formatTime) + v, err := utils.TimeParse(d, utils.FormatTime) if err == nil { e.Set(v) } @@ -249,7 +251,7 @@ func (e *DateField) SetRaw(value interface{}) error { case time.Time: e.Set(d) case string: - v, err := timeParse(d, formatDate) + v, err := utils.TimeParse(d, utils.FormatDate) if err == nil { e.Set(v) } @@ -299,7 +301,7 @@ func (e *DateTimeField) SetRaw(value interface{}) error { case time.Time: e.Set(d) case string: - v, err := timeParse(d, formatDateTime) + v, err := utils.TimeParse(d, utils.FormatDateTime) if err == nil { e.Set(v) } @@ -333,7 +335,7 @@ func (e *FloatField) Set(d float64) { // String return the string func (e *FloatField) String() string { - return ToStr(e.Value(), -1, 32) + return utils.ToStr(e.Value(), -1, 32) } // FieldType return the enum type @@ -349,7 +351,7 @@ func (e *FloatField) SetRaw(value interface{}) error { case float64: e.Set(d) case string: - v, err := StrTo(d).Float64() + v, err := utils.StrTo(d).Float64() if err == nil { e.Set(v) } @@ -383,7 +385,7 @@ func (e *SmallIntegerField) Set(d int16) { // String convert smallint to string func (e *SmallIntegerField) String() string { - return ToStr(e.Value()) + return utils.ToStr(e.Value()) } // FieldType return enum type SmallIntegerField @@ -397,7 +399,7 @@ func (e *SmallIntegerField) SetRaw(value interface{}) error { case int16: e.Set(d) case string: - v, err := StrTo(d).Int16() + v, err := utils.StrTo(d).Int16() if err == nil { e.Set(v) } @@ -431,7 +433,7 @@ func (e *IntegerField) Set(d int32) { // String convert Int32 to string func (e *IntegerField) String() string { - return ToStr(e.Value()) + return utils.ToStr(e.Value()) } // FieldType return the enum type @@ -445,7 +447,7 @@ func (e *IntegerField) SetRaw(value interface{}) error { case int32: e.Set(d) case string: - v, err := StrTo(d).Int32() + v, err := utils.StrTo(d).Int32() if err == nil { e.Set(v) } @@ -479,7 +481,7 @@ func (e *BigIntegerField) Set(d int64) { // String convert BigIntegerField to string func (e *BigIntegerField) String() string { - return ToStr(e.Value()) + return utils.ToStr(e.Value()) } // FieldType return enum type @@ -493,7 +495,7 @@ func (e *BigIntegerField) SetRaw(value interface{}) error { case int64: e.Set(d) case string: - v, err := StrTo(d).Int64() + v, err := utils.StrTo(d).Int64() if err == nil { e.Set(v) } @@ -527,7 +529,7 @@ func (e *PositiveSmallIntegerField) Set(d uint16) { // String convert uint16 to string func (e *PositiveSmallIntegerField) String() string { - return ToStr(e.Value()) + return utils.ToStr(e.Value()) } // FieldType return enum type @@ -541,7 +543,7 @@ func (e *PositiveSmallIntegerField) SetRaw(value interface{}) error { case uint16: e.Set(d) case string: - v, err := StrTo(d).Uint16() + v, err := utils.StrTo(d).Uint16() if err == nil { e.Set(v) } @@ -575,7 +577,7 @@ func (e *PositiveIntegerField) Set(d uint32) { // String convert PositiveIntegerField to string func (e *PositiveIntegerField) String() string { - return ToStr(e.Value()) + return utils.ToStr(e.Value()) } // FieldType return enum type @@ -589,7 +591,7 @@ func (e *PositiveIntegerField) SetRaw(value interface{}) error { case uint32: e.Set(d) case string: - v, err := StrTo(d).Uint32() + v, err := utils.StrTo(d).Uint32() if err == nil { e.Set(v) } @@ -623,7 +625,7 @@ func (e *PositiveBigIntegerField) Set(d uint64) { // String convert PositiveBigIntegerField to string func (e *PositiveBigIntegerField) String() string { - return ToStr(e.Value()) + return utils.ToStr(e.Value()) } // FieldType return enum type @@ -637,7 +639,7 @@ func (e *PositiveBigIntegerField) SetRaw(value interface{}) error { case uint64: e.Set(d) case string: - v, err := StrTo(d).Uint64() + v, err := utils.StrTo(d).Uint64() if err == nil { e.Set(v) } diff --git a/orm/models_info_f.go b/client/orm/internal/models/models_info_f.go similarity index 57% rename from orm/models_info_f.go rename to client/orm/internal/models/models_info_f.go index 7044b0bdba..474935b402 100644 --- a/orm/models_info_f.go +++ b/client/orm/internal/models/models_info_f.go @@ -12,146 +12,150 @@ // See the License for the specific language governing permissions and // limitations under the License. -package orm +package models import ( "errors" "fmt" "reflect" "strings" + + "github.com/beego/beego/v2/client/orm/internal/utils" ) var errSkipField = errors.New("skip field") -// field info collection -type fields struct { - pk *fieldInfo - columns map[string]*fieldInfo - fields map[string]*fieldInfo - fieldsLow map[string]*fieldInfo - fieldsByType map[int][]*fieldInfo - fieldsRel []*fieldInfo - fieldsReverse []*fieldInfo - fieldsDB []*fieldInfo - rels []*fieldInfo - orders []string - dbcols []string +// Fields field info collection +type Fields struct { + Pk *FieldInfo + Columns map[string]*FieldInfo + Fields map[string]*FieldInfo + FieldsLow map[string]*FieldInfo + FieldsByType map[int][]*FieldInfo + FieldsRel []*FieldInfo + FieldsReverse []*FieldInfo + FieldsDB []*FieldInfo + Rels []*FieldInfo + Orders []string + DBcols []string } -// add field info -func (f *fields) Add(fi *fieldInfo) (added bool) { - if f.fields[fi.name] == nil && f.columns[fi.column] == nil { - f.columns[fi.column] = fi - f.fields[fi.name] = fi - f.fieldsLow[strings.ToLower(fi.name)] = fi +// Add adds field info +func (f *Fields) Add(fi *FieldInfo) (added bool) { + if f.Fields[fi.Name] == nil && f.Columns[fi.Column] == nil { + f.Columns[fi.Column] = fi + f.Fields[fi.Name] = fi + f.FieldsLow[strings.ToLower(fi.Name)] = fi } else { return } - if _, ok := f.fieldsByType[fi.fieldType]; !ok { - f.fieldsByType[fi.fieldType] = make([]*fieldInfo, 0) + if _, ok := f.FieldsByType[fi.FieldType]; !ok { + f.FieldsByType[fi.FieldType] = make([]*FieldInfo, 0) } - f.fieldsByType[fi.fieldType] = append(f.fieldsByType[fi.fieldType], fi) - f.orders = append(f.orders, fi.column) - if fi.dbcol { - f.dbcols = append(f.dbcols, fi.column) - f.fieldsDB = append(f.fieldsDB, fi) + f.FieldsByType[fi.FieldType] = append(f.FieldsByType[fi.FieldType], fi) + f.Orders = append(f.Orders, fi.Column) + if fi.DBcol { + f.DBcols = append(f.DBcols, fi.Column) + f.FieldsDB = append(f.FieldsDB, fi) } - if fi.rel { - f.fieldsRel = append(f.fieldsRel, fi) + if fi.Rel { + f.FieldsRel = append(f.FieldsRel, fi) } - if fi.reverse { - f.fieldsReverse = append(f.fieldsReverse, fi) + if fi.Reverse { + f.FieldsReverse = append(f.FieldsReverse, fi) } return true } -// get field info by name -func (f *fields) GetByName(name string) *fieldInfo { - return f.fields[name] +// GetByName get field info by name +func (f *Fields) GetByName(name string) *FieldInfo { + return f.Fields[name] } -// get field info by column name -func (f *fields) GetByColumn(column string) *fieldInfo { - return f.columns[column] +// GetByColumn get field info by column name +func (f *Fields) GetByColumn(column string) *FieldInfo { + return f.Columns[column] } -// get field info by string, name is prior -func (f *fields) GetByAny(name string) (*fieldInfo, bool) { - if fi, ok := f.fields[name]; ok { +// GetByAny get field info by string, name is prior +func (f *Fields) GetByAny(name string) (*FieldInfo, bool) { + if fi, ok := f.Fields[name]; ok { return fi, ok } - if fi, ok := f.fieldsLow[strings.ToLower(name)]; ok { + if fi, ok := f.FieldsLow[strings.ToLower(name)]; ok { return fi, ok } - if fi, ok := f.columns[name]; ok { + if fi, ok := f.Columns[name]; ok { return fi, ok } return nil, false } -// create new field info collection -func newFields() *fields { - f := new(fields) - f.fields = make(map[string]*fieldInfo) - f.fieldsLow = make(map[string]*fieldInfo) - f.columns = make(map[string]*fieldInfo) - f.fieldsByType = make(map[int][]*fieldInfo) +// NewFields create new field info collection +func NewFields() *Fields { + f := new(Fields) + f.Fields = make(map[string]*FieldInfo) + f.FieldsLow = make(map[string]*FieldInfo) + f.Columns = make(map[string]*FieldInfo) + f.FieldsByType = make(map[int][]*FieldInfo) return f } -// single field info -type fieldInfo struct { - mi *modelInfo - fieldIndex []int - fieldType int - dbcol bool // table column fk and onetoone - inModel bool - name string - fullName string - column string - addrValue reflect.Value - sf reflect.StructField - auto bool - pk bool - null bool - index bool - unique bool - colDefault bool // whether has default tag - initial StrTo // store the default value - size int - toText bool - autoNow bool - autoNowAdd bool - rel bool // if type equal to RelForeignKey, RelOneToOne, RelManyToMany then true - reverse bool - reverseField string - reverseFieldInfo *fieldInfo - reverseFieldInfoTwo *fieldInfo - reverseFieldInfoM2M *fieldInfo - relTable string - relThrough string - relThroughModelInfo *modelInfo - relModelInfo *modelInfo - digits int - decimals int - isFielder bool // implement Fielder interface - onDelete string - description string +// FieldInfo single field info +type FieldInfo struct { + DBcol bool // table column fk and onetoone + InModel bool + Auto bool + Pk bool + Null bool + Index bool + Unique bool + ColDefault bool // whether has default tag + ToText bool + AutoNow bool + AutoNowAdd bool + Rel bool // if type equal to RelForeignKey, RelOneToOne, RelManyToMany then true + Reverse bool + IsFielder bool // implement Fielder interface + Mi *ModelInfo + FieldIndex []int + FieldType int + Name string + FullName string + Column string + AddrValue reflect.Value + Sf reflect.StructField + Initial utils.StrTo // store the default value + Size int + ReverseField string + ReverseFieldInfo *FieldInfo + ReverseFieldInfoTwo *FieldInfo + ReverseFieldInfoM2M *FieldInfo + RelTable string + RelThrough string + RelThroughModelInfo *ModelInfo + RelModelInfo *ModelInfo + Digits int + Decimals int + OnDelete string + Description string + TimePrecision *int + DBType string } -// new field info -func newFieldInfo(mi *modelInfo, field reflect.Value, sf reflect.StructField, mName string) (fi *fieldInfo, err error) { +// NewFieldInfo new field info +func NewFieldInfo(mi *ModelInfo, field reflect.Value, sf reflect.StructField, mName string) (fi *FieldInfo, err error) { var ( tag string tagValue string - initial StrTo // store the default value + initial utils.StrTo // store the default value fieldType int attrs map[string]bool tags map[string]string addrField reflect.Value ) - fi = new(fieldInfo) + fi = new(FieldInfo) // if field which CanAddr is the follow type // A value is addressable if it is an element of a slice, @@ -167,7 +171,7 @@ func newFieldInfo(mi *modelInfo, field reflect.Value, sf reflect.StructField, mN } } - attrs, tags = parseStructTag(sf.Tag.Get(defaultStructTagName)) + attrs, tags = ParseStructTag(sf.Tag.Get(DefaultStructTagName)) if _, ok := attrs["-"]; ok { return nil, errSkipField @@ -177,7 +181,7 @@ func newFieldInfo(mi *modelInfo, field reflect.Value, sf reflect.StructField, mN decimals := tags["decimals"] size := tags["size"] onDelete := tags["on_delete"] - + precision := tags["precision"] initial.Clear() if v, ok := tags["default"]; ok { initial.Set(v) @@ -186,14 +190,14 @@ func newFieldInfo(mi *modelInfo, field reflect.Value, sf reflect.StructField, mN checkType: switch f := addrField.Interface().(type) { case Fielder: - fi.isFielder = true + fi.IsFielder = true if field.Kind() == reflect.Ptr { err = fmt.Errorf("the model Fielder can not be use ptr") goto end } fieldType = f.FieldType() if fieldType&IsRelField > 0 { - err = fmt.Errorf("unsupport type custom field, please refer to https://github.com/astaxie/beego/blob/master/orm/models_fields.go#L24-L42") + err = fmt.Errorf("unsupport type custom field, please refer to https://github.com/beego/beego/v2/blob/master/orm/models_fields.go#L24-L42") goto end } default: @@ -210,9 +214,9 @@ checkType: case "m2m": fieldType = RelManyToMany if tv := tags["rel_table"]; tv != "" { - fi.relTable = tv + fi.RelTable = tv } else if tv := tags["rel_through"]; tv != "" { - fi.relThrough = tv + fi.RelThrough = tv } break checkType default: @@ -230,9 +234,9 @@ checkType: case "many": fieldType = RelReverseMany if tv := tags["rel_table"]; tv != "" { - fi.relTable = tv + fi.RelTable = tv } else if tv := tags["rel_through"]; tv != "" { - fi.relThrough = tv + fi.RelThrough = tv } break checkType default: @@ -294,106 +298,118 @@ checkType: goto end } - fi.fieldType = fieldType - fi.name = sf.Name - fi.column = getColumnName(fieldType, addrField, sf, tags["column"]) - fi.addrValue = addrField - fi.sf = sf - fi.fullName = mi.fullName + mName + "." + sf.Name - - fi.description = tags["description"] - fi.null = attrs["null"] - fi.index = attrs["index"] - fi.auto = attrs["auto"] - fi.pk = attrs["pk"] - fi.unique = attrs["unique"] + fi.FieldType = fieldType + fi.Name = sf.Name + fi.Column = getColumnName(fieldType, addrField, sf, tags["column"]) + fi.AddrValue = addrField + fi.Sf = sf + fi.FullName = mi.FullName + mName + "." + sf.Name + + fi.Description = tags["description"] + fi.Null = attrs["null"] + fi.Index = attrs["index"] + fi.Auto = attrs["auto"] + fi.DBType = tags["db_type"] + fi.Pk = attrs["pk"] + fi.Unique = attrs["unique"] // Mark object property if there is attribute "default" in the orm configuration if _, ok := tags["default"]; ok { - fi.colDefault = true + fi.ColDefault = true } switch fieldType { case RelManyToMany, RelReverseMany, RelReverseOne: - fi.null = false - fi.index = false - fi.auto = false - fi.pk = false - fi.unique = false + fi.Null = false + fi.Index = false + fi.Auto = false + fi.Pk = false + fi.Unique = false default: - fi.dbcol = true + fi.DBcol = true } switch fieldType { case RelForeignKey, RelOneToOne, RelManyToMany: - fi.rel = true + fi.Rel = true if fieldType == RelOneToOne { - fi.unique = true + fi.Unique = true } case RelReverseMany, RelReverseOne: - fi.reverse = true + fi.Reverse = true } - if fi.rel && fi.dbcol { + if fi.Rel && fi.DBcol { switch onDelete { - case odCascade, odDoNothing: - case odSetDefault: + case OdCascade, OdDoNothing: + case OdSetDefault: if !initial.Exist() { err = errors.New("on_delete: set_default need set field a default value") goto end } - case odSetNULL: - if !fi.null { + case OdSetNULL: + if !fi.Null { err = errors.New("on_delete: set_null need set field null") goto end } default: if onDelete == "" { - onDelete = odCascade + onDelete = OdCascade } else { err = fmt.Errorf("on_delete value expected choice in `cascade,set_null,set_default,do_nothing`, unknown `%s`", onDelete) goto end } } - fi.onDelete = onDelete + fi.OnDelete = onDelete } switch fieldType { case TypeBooleanField: case TypeVarCharField, TypeCharField, TypeJSONField, TypeJsonbField: if size != "" { - v, e := StrTo(size).Int32() + v, e := utils.StrTo(size).Int32() if e != nil { err = fmt.Errorf("wrong size value `%s`", size) } else { - fi.size = int(v) + fi.Size = int(v) } } else { - fi.size = 255 - fi.toText = true + fi.Size = 255 + fi.ToText = true } case TypeTextField: - fi.index = false - fi.unique = false + fi.Index = false + fi.Unique = false case TypeTimeField, TypeDateField, TypeDateTimeField: + if fieldType == TypeDateTimeField { + if precision != "" { + v, e := utils.StrTo(precision).Int() + if e != nil { + err = fmt.Errorf("convert %s to int error:%v", precision, e) + } else { + fi.TimePrecision = &v + } + } + } + if attrs["auto_now"] { - fi.autoNow = true + fi.AutoNow = true } else if attrs["auto_now_add"] { - fi.autoNowAdd = true + fi.AutoNowAdd = true } case TypeFloatField: case TypeDecimalField: d1 := digits d2 := decimals - v1, er1 := StrTo(d1).Int8() - v2, er2 := StrTo(d2).Int8() + v1, er1 := utils.StrTo(d1).Int8() + v2, er2 := utils.StrTo(d2).Int8() if er1 != nil || er2 != nil { err = fmt.Errorf("wrong digits/decimals value %s/%s", d2, d1) goto end } - fi.digits = int(v1) - fi.decimals = int(v2) + fi.Digits = int(v1) + fi.Decimals = int(v2) default: switch { case fieldType&IsIntegerField > 0: @@ -402,33 +418,33 @@ checkType: } if fieldType&IsIntegerField == 0 { - if fi.auto { + if fi.Auto { err = fmt.Errorf("non-integer type cannot set auto") goto end } } - if fi.auto || fi.pk { - if fi.auto { + if fi.Auto || fi.Pk { + if fi.Auto { switch addrField.Elem().Kind() { case reflect.Int, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint32, reflect.Uint64: default: err = fmt.Errorf("auto primary key only support int, int32, int64, uint, uint32, uint64 but found `%s`", addrField.Elem().Kind()) goto end } - fi.pk = true + fi.Pk = true } - fi.null = false - fi.index = false - fi.unique = false + fi.Null = false + fi.Index = false + fi.Unique = false } - if fi.unique { - fi.index = false + if fi.Unique { + fi.Index = false } // can not set default for these type - if fi.auto || fi.pk || fi.unique || fieldType == TypeTimeField || fieldType == TypeDateField || fieldType == TypeDateTimeField { + if fi.Auto || fi.Pk || fi.Unique || fieldType == TypeTimeField || fieldType == TypeDateField || fieldType == TypeDateTimeField { initial.Clear() } @@ -462,7 +478,7 @@ checkType: } } - fi.initial = initial + fi.Initial = initial end: if err != nil { return nil, err diff --git a/client/orm/internal/models/models_info_m.go b/client/orm/internal/models/models_info_m.go new file mode 100644 index 0000000000..0dee0aa81c --- /dev/null +++ b/client/orm/internal/models/models_info_m.go @@ -0,0 +1,148 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package models + +import ( + "fmt" + "os" + "reflect" +) + +// ModelInfo single model info +type ModelInfo struct { + Manual bool + IsThrough bool + Pkg string + Name string + FullName string + Table string + Model interface{} + Fields *Fields + AddrField reflect.Value // store the original struct value + Uniques []string +} + +// NewModelInfo new model info +func NewModelInfo(val reflect.Value) (mi *ModelInfo) { + mi = &ModelInfo{} + mi.Fields = NewFields() + ind := reflect.Indirect(val) + mi.AddrField = val + mi.Name = ind.Type().Name() + mi.FullName = GetFullName(ind.Type()) + AddModelFields(mi, ind, "", []int{}) + return +} + +// AddModelFields index: FieldByIndex returns the nested field corresponding to index +func AddModelFields(mi *ModelInfo, ind reflect.Value, mName string, index []int) { + var ( + err error + fi *FieldInfo + sf reflect.StructField + ) + + for i := 0; i < ind.NumField(); i++ { + field := ind.Field(i) + sf = ind.Type().Field(i) + // if the field is unexported skip + if sf.PkgPath != "" { + continue + } + // add anonymous struct Fields + if sf.Anonymous { + AddModelFields(mi, field, mName+"."+sf.Name, append(index, i)) + continue + } + + fi, err = NewFieldInfo(mi, field, sf, mName) + if err == errSkipField { + err = nil + continue + } else if err != nil { + break + } + // record current field index + fi.FieldIndex = append(fi.FieldIndex, index...) + fi.FieldIndex = append(fi.FieldIndex, i) + fi.Mi = mi + fi.InModel = true + if !mi.Fields.Add(fi) { + err = fmt.Errorf("duplicate column name: %s", fi.Column) + break + } + if fi.Pk { + if mi.Fields.Pk != nil { + err = fmt.Errorf("one model must have one pk field only") + break + } else { + mi.Fields.Pk = fi + } + } + } + + if err != nil { + fmt.Println(fmt.Errorf("field: %s.%s, %s", ind.Type(), sf.Name, err)) + os.Exit(2) + } +} + +// NewM2MModelInfo combine related model info to new model info. +// prepare for relation models query. +func NewM2MModelInfo(m1, m2 *ModelInfo) (mi *ModelInfo) { + mi = new(ModelInfo) + mi.Fields = NewFields() + mi.Table = m1.Table + "_" + m2.Table + "s" + mi.Name = CamelString(mi.Table) + mi.FullName = m1.Pkg + "." + mi.Name + + fa := new(FieldInfo) // pk + f1 := new(FieldInfo) // m1 table RelForeignKey + f2 := new(FieldInfo) // m2 table RelForeignKey + fa.FieldType = TypeBigIntegerField + fa.Auto = true + fa.Pk = true + fa.DBcol = true + fa.Name = "Id" + fa.Column = "id" + fa.FullName = mi.FullName + "." + fa.Name + + f1.DBcol = true + f2.DBcol = true + f1.FieldType = RelForeignKey + f2.FieldType = RelForeignKey + f1.Name = CamelString(m1.Table) + f2.Name = CamelString(m2.Table) + f1.FullName = mi.FullName + "." + f1.Name + f2.FullName = mi.FullName + "." + f2.Name + f1.Column = m1.Table + "_id" + f2.Column = m2.Table + "_id" + f1.Rel = true + f2.Rel = true + f1.RelTable = m1.Table + f2.RelTable = m2.Table + f1.RelModelInfo = m1 + f2.RelModelInfo = m2 + f1.Mi = mi + f2.Mi = mi + + mi.Fields.Add(fa) + mi.Fields.Add(f1) + mi.Fields.Add(f2) + mi.Fields.Pk = fa + + mi.Uniques = []string{f1.Column, f2.Column} + return +} diff --git a/client/orm/internal/models/models_test.go b/client/orm/internal/models/models_test.go new file mode 100644 index 0000000000..0963cc0df8 --- /dev/null +++ b/client/orm/internal/models/models_test.go @@ -0,0 +1,49 @@ +package models + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +type Interface struct { + Id int + Name string + + Index1 string + Index2 string + + Unique1 string + Unique2 string +} + +func (i *Interface) TableIndex() [][]string { + return [][]string{{"index1"}, {"index2"}} +} + +func (i *Interface) TableUnique() [][]string { + return [][]string{{"unique1"}, {"unique2"}} +} + +func (i *Interface) TableName() string { + return "INTERFACE_" +} + +func (i *Interface) TableEngine() string { + return "innodb" +} + +func TestDbBase_GetTables(t *testing.T) { + registry := NewModelCacheHandler() + registry.Register("", true, &Interface{}) + mi, ok := registry.Get("INTERFACE_") + assert.True(t, ok) + assert.NotNil(t, mi) + + engine := GetTableEngine(mi.AddrField) + assert.Equal(t, "innodb", engine) + uniques := GetTableUnique(mi.AddrField) + assert.Equal(t, [][]string{{"unique1"}, {"unique2"}}, uniques) + indexes := GetTableIndex(mi.AddrField) + assert.Equal(t, [][]string{{"index1"}, {"index2"}}, indexes) +} diff --git a/orm/models_utils.go b/client/orm/internal/models/models_utils.go similarity index 64% rename from orm/models_utils.go rename to client/orm/internal/models/models_utils.go index 13a22aca54..3a96d23225 100644 --- a/orm/models_utils.go +++ b/client/orm/internal/models/models_utils.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package orm +package models import ( "database/sql" @@ -45,17 +45,33 @@ var supportTag = map[string]int{ "on_delete": 2, "type": 2, "description": 2, + "precision": 2, + "db_type": 2, } -// get reflect.Type name with package path. -func getFullName(typ reflect.Type) string { +type fn func(string) string + +var ( + NameStrategyMap = map[string]fn{ + DefaultNameStrategy: SnakeString, + SnakeAcronymNameStrategy: SnakeStringWithAcronym, + } + DefaultNameStrategy = "snakeString" + SnakeAcronymNameStrategy = "snakeStringWithAcronym" + NameStrategy = DefaultNameStrategy + defaultStructTagDelim = ";" + DefaultStructTagName = "orm" +) + +// GetFullName get reflect.Type name with package path. +func GetFullName(typ reflect.Type) string { return typ.PkgPath() + "." + typ.Name() } -// getTableName get struct table name. +// GetTableName get struct table name. // If the struct implement the TableName, then get the result as tablename // else use the struct name which will apply snakeString. -func getTableName(val reflect.Value) string { +func GetTableName(val reflect.Value) string { if fun := val.MethodByName("TableName"); fun.IsValid() { vals := fun.Call([]reflect.Value{}) // has return and the first val is string @@ -63,11 +79,11 @@ func getTableName(val reflect.Value) string { return vals[0].String() } } - return snakeString(reflect.Indirect(val).Type().Name()) + return SnakeString(reflect.Indirect(val).Type().Name()) } -// get table engine, mysiam or innodb. -func getTableEngine(val reflect.Value) string { +// GetTableEngine get table engine, myisam or innodb. +func GetTableEngine(val reflect.Value) string { fun := val.MethodByName("TableEngine") if fun.IsValid() { vals := fun.Call([]reflect.Value{}) @@ -78,8 +94,8 @@ func getTableEngine(val reflect.Value) string { return "" } -// get table index from method. -func getTableIndex(val reflect.Value) [][]string { +// GetTableIndex get table index from method. +func GetTableIndex(val reflect.Value) [][]string { fun := val.MethodByName("TableIndex") if fun.IsValid() { vals := fun.Call([]reflect.Value{}) @@ -92,8 +108,8 @@ func getTableIndex(val reflect.Value) [][]string { return nil } -// get table unique from method -func getTableUnique(val reflect.Value) [][]string { +// GetTableUnique get table unique from method +func GetTableUnique(val reflect.Value) [][]string { fun := val.MethodByName("TableUnique") if fun.IsValid() { vals := fun.Call([]reflect.Value{}) @@ -106,11 +122,26 @@ func getTableUnique(val reflect.Value) [][]string { return nil } +// IsApplicableTableForDB get whether the table needs to be created for the database alias +func IsApplicableTableForDB(val reflect.Value, db string) bool { + if !val.IsValid() { + return true + } + fun := val.MethodByName("IsApplicableTableForDB") + if fun.IsValid() { + vals := fun.Call([]reflect.Value{reflect.ValueOf(db)}) + if len(vals) > 0 && vals[0].Kind() == reflect.Bool { + return vals[0].Bool() + } + } + return true +} + // get snaked column name func getColumnName(ft int, addrField reflect.Value, sf reflect.StructField, col string) string { column := col if col == "" { - column = nameStrategyMap[nameStrategy](sf.Name) + column = NameStrategyMap[NameStrategy](sf.Name) } switch ft { case RelForeignKey, RelOneToOne: @@ -202,8 +233,8 @@ func getFieldType(val reflect.Value) (ft int, err error) { return } -// parse struct tag string -func parseStructTag(data string) (attrs map[string]bool, tags map[string]string) { +// ParseStructTag parse struct tag string +func ParseStructTag(data string) (attrs map[string]bool, tags map[string]string) { attrs = make(map[string]bool) tags = make(map[string]string) for _, v := range strings.Split(data, defaultStructTagDelim) { @@ -219,9 +250,73 @@ func parseStructTag(data string) (attrs map[string]bool, tags map[string]string) v = v[i+1 : len(v)-1] tags[name] = v } - } else { - DebugLog.Println("unsupport orm tag", v) } } return } + +func SnakeStringWithAcronym(s string) string { + data := make([]byte, 0, len(s)*2) + num := len(s) + for i := 0; i < num; i++ { + d := s[i] + before := false + after := false + if i > 0 { + before = s[i-1] >= 'a' && s[i-1] <= 'z' + } + if i+1 < num { + after = s[i+1] >= 'a' && s[i+1] <= 'z' + } + if i > 0 && d >= 'A' && d <= 'Z' && (before || after) { + data = append(data, '_') + } + data = append(data, d) + } + return strings.ToLower(string(data)) +} + +// SnakeString snake string, XxYy to xx_yy , XxYY to xx_y_y +func SnakeString(s string) string { + data := make([]byte, 0, len(s)*2) + j := false + num := len(s) + for i := 0; i < num; i++ { + d := s[i] + if i > 0 && d >= 'A' && d <= 'Z' && j { + data = append(data, '_') + } + if d != '_' { + j = true + } + data = append(data, d) + } + return strings.ToLower(string(data)) +} + +// CamelString camel string, xx_yy to XxYy +func CamelString(s string) string { + data := make([]byte, 0, len(s)) + flag, num := true, len(s)-1 + for i := 0; i <= num; i++ { + d := s[i] + if d == '_' { + flag = true + continue + } else if flag { + if d >= 'a' && d <= 'z' { + d = d - 32 + } + flag = false + } + data = append(data, d) + } + return string(data) +} + +const ( + OdCascade = "cascade" + OdSetNULL = "set_null" + OdSetDefault = "set_default" + OdDoNothing = "do_nothing" +) diff --git a/orm/utils_test.go b/client/orm/internal/models/models_utils_test.go similarity index 75% rename from orm/utils_test.go rename to client/orm/internal/models/models_utils_test.go index 7d94cada45..40bffc6661 100644 --- a/orm/utils_test.go +++ b/client/orm/internal/models/models_utils_test.go @@ -1,10 +1,10 @@ -// Copyright 2014 beego Author. All Rights Reserved. +// Copyright 2020 // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // -// http://www.apache.org/licenses/LICENSE-2.0 +// http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, @@ -12,27 +12,26 @@ // See the License for the specific language governing permissions and // limitations under the License. -package orm +package models import ( + "reflect" "testing" + + "github.com/stretchr/testify/assert" ) -func TestCamelString(t *testing.T) { - snake := []string{"pic_url", "hello_world_", "hello__World", "_HelLO_Word", "pic_url_1", "pic_url__1"} - camel := []string{"PicUrl", "HelloWorld", "HelloWorld", "HelLOWord", "PicUrl1", "PicUrl1"} +type NotApplicableModel struct { + Id int +} - answer := make(map[string]string) - for i, v := range snake { - answer[v] = camel[i] - } +func (n *NotApplicableModel) IsApplicableTableForDB(db string) bool { + return db == "default" +} - for _, v := range snake { - res := camelString(v) - if res != answer[v] { - t.Error("Unit Test Fail:", v, res, answer[v]) - } - } +func TestIsApplicableTableForDB(t *testing.T) { + assert.False(t, IsApplicableTableForDB(reflect.ValueOf(&NotApplicableModel{}), "defa")) + assert.True(t, IsApplicableTableForDB(reflect.ValueOf(&NotApplicableModel{}), "default")) } func TestSnakeString(t *testing.T) { @@ -45,7 +44,7 @@ func TestSnakeString(t *testing.T) { } for _, v := range camel { - res := snakeString(v) + res := SnakeString(v) if res != answer[v] { t.Error("Unit Test Fail:", v, res, answer[v]) } @@ -62,7 +61,24 @@ func TestSnakeStringWithAcronym(t *testing.T) { } for _, v := range camel { - res := snakeStringWithAcronym(v) + res := SnakeStringWithAcronym(v) + if res != answer[v] { + t.Error("Unit Test Fail:", v, res, answer[v]) + } + } +} + +func TestCamelString(t *testing.T) { + snake := []string{"pic_url", "hello_world_", "hello__World", "_HelLO_Word", "pic_url_1", "pic_url__1"} + camel := []string{"PicUrl", "HelloWorld", "HelloWorld", "HelLOWord", "PicUrl1", "PicUrl1"} + + answer := make(map[string]string) + for i, v := range snake { + answer[v] = camel[i] + } + + for _, v := range snake { + res := CamelString(v) if res != answer[v] { t.Error("Unit Test Fail:", v, res, answer[v]) } diff --git a/logs/conn_test.go b/client/orm/internal/models/types.go similarity index 70% rename from logs/conn_test.go rename to client/orm/internal/models/types.go index 747fb890e0..f3b7989aba 100644 --- a/logs/conn_test.go +++ b/client/orm/internal/models/types.go @@ -1,4 +1,4 @@ -// Copyright 2014 beego Author. All Rights Reserved. +// Copyright 2023 beego-dev. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,14 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. -package logs +package models -import ( - "testing" -) - -func TestConn(t *testing.T) { - log := NewLogger(1000) - log.SetLogger("conn", `{"net":"tcp","addr":":7020"}`) - log.Informational("informational") +// Fielder define field info +type Fielder interface { + String() string + FieldType() int + SetRaw(interface{}) error + RawValue() interface{} } diff --git a/orm/utils.go b/client/orm/internal/utils/utils.go similarity index 57% rename from orm/utils.go rename to client/orm/internal/utils/utils.go index 783927713a..6aa713d7e2 100644 --- a/orm/utils.go +++ b/client/orm/internal/utils/utils.go @@ -12,29 +12,16 @@ // See the License for the specific language governing permissions and // limitations under the License. -package orm +package utils import ( "fmt" "math/big" "reflect" "strconv" - "strings" "time" ) -type fn func(string) string - -var ( - nameStrategyMap = map[string]fn{ - defaultNameStrategy: snakeString, - SnakeAcronymNameStrategy: snakeStringWithAcronym, - } - defaultNameStrategy = "snakeString" - SnakeAcronymNameStrategy = "snakeStringWithAcronym" - nameStrategy = defaultNameStrategy -) - // StrTo is the target string type StrTo string @@ -49,12 +36,12 @@ func (f *StrTo) Set(v string) { // Clear string func (f *StrTo) Clear() { - *f = StrTo(0x1E) + *f = StrTo(rune(0x1E)) } // Exist check string exist func (f StrTo) Exist() bool { - return string(f) != string(0x1E) + return string(f) != string(rune(0x1E)) } // Bool string to bool @@ -75,7 +62,7 @@ func (f StrTo) Float64() (float64, error) { // Int string to int func (f StrTo) Int() (int, error) { - v, err := strconv.ParseInt(f.String(), 10, 32) + v, err := strconv.ParseInt(f.String(), 10, strconv.IntSize) return int(v), err } @@ -113,7 +100,7 @@ func (f StrTo) Int64() (int64, error) { // Uint string to uint func (f StrTo) Uint() (uint, error) { - v, err := strconv.ParseUint(f.String(), 10, 32) + v, err := strconv.ParseUint(f.String(), 10, strconv.IntSize) return uint(v), err } @@ -129,7 +116,7 @@ func (f StrTo) Uint16() (uint16, error) { return uint16(v), err } -// Uint32 string to uint31 +// Uint32 string to uint32 func (f StrTo) Uint32() (uint32, error) { v, err := strconv.ParseUint(f.String(), 10, 32) return uint32(v), err @@ -163,29 +150,29 @@ func ToStr(value interface{}, args ...int) (s string) { case bool: s = strconv.FormatBool(v) case float32: - s = strconv.FormatFloat(float64(v), 'f', argInt(args).Get(0, -1), argInt(args).Get(1, 32)) + s = strconv.FormatFloat(float64(v), 'f', ArgInt(args).Get(0, -1), ArgInt(args).Get(1, 32)) case float64: - s = strconv.FormatFloat(v, 'f', argInt(args).Get(0, -1), argInt(args).Get(1, 64)) + s = strconv.FormatFloat(v, 'f', ArgInt(args).Get(0, -1), ArgInt(args).Get(1, 64)) case int: - s = strconv.FormatInt(int64(v), argInt(args).Get(0, 10)) + s = strconv.FormatInt(int64(v), ArgInt(args).Get(0, 10)) case int8: - s = strconv.FormatInt(int64(v), argInt(args).Get(0, 10)) + s = strconv.FormatInt(int64(v), ArgInt(args).Get(0, 10)) case int16: - s = strconv.FormatInt(int64(v), argInt(args).Get(0, 10)) + s = strconv.FormatInt(int64(v), ArgInt(args).Get(0, 10)) case int32: - s = strconv.FormatInt(int64(v), argInt(args).Get(0, 10)) + s = strconv.FormatInt(int64(v), ArgInt(args).Get(0, 10)) case int64: - s = strconv.FormatInt(v, argInt(args).Get(0, 10)) + s = strconv.FormatInt(v, ArgInt(args).Get(0, 10)) case uint: - s = strconv.FormatUint(uint64(v), argInt(args).Get(0, 10)) + s = strconv.FormatUint(uint64(v), ArgInt(args).Get(0, 10)) case uint8: - s = strconv.FormatUint(uint64(v), argInt(args).Get(0, 10)) + s = strconv.FormatUint(uint64(v), ArgInt(args).Get(0, 10)) case uint16: - s = strconv.FormatUint(uint64(v), argInt(args).Get(0, 10)) + s = strconv.FormatUint(uint64(v), ArgInt(args).Get(0, 10)) case uint32: - s = strconv.FormatUint(uint64(v), argInt(args).Get(0, 10)) + s = strconv.FormatUint(uint64(v), ArgInt(args).Get(0, 10)) case uint64: - s = strconv.FormatUint(v, argInt(args).Get(0, 10)) + s = strconv.FormatUint(v, ArgInt(args).Get(0, 10)) case string: s = v case []byte: @@ -210,77 +197,10 @@ func ToInt64(value interface{}) (d int64) { return } -func snakeStringWithAcronym(s string) string { - data := make([]byte, 0, len(s)*2) - num := len(s) - for i := 0; i < num; i++ { - d := s[i] - before := false - after := false - if i > 0 { - before = s[i-1] >= 'a' && s[i-1] <= 'z' - } - if i+1 < num { - after = s[i+1] >= 'a' && s[i+1] <= 'z' - } - if i > 0 && d >= 'A' && d <= 'Z' && (before || after) { - data = append(data, '_') - } - data = append(data, d) - } - return strings.ToLower(string(data[:])) -} - -// snake string, XxYy to xx_yy , XxYY to xx_y_y -func snakeString(s string) string { - data := make([]byte, 0, len(s)*2) - j := false - num := len(s) - for i := 0; i < num; i++ { - d := s[i] - if i > 0 && d >= 'A' && d <= 'Z' && j { - data = append(data, '_') - } - if d != '_' { - j = true - } - data = append(data, d) - } - return strings.ToLower(string(data[:])) -} - -// SetNameStrategy set different name strategy -func SetNameStrategy(s string) { - if SnakeAcronymNameStrategy != s { - nameStrategy = defaultNameStrategy - } - nameStrategy = s -} +type ArgString []string -// camel string, xx_yy to XxYy -func camelString(s string) string { - data := make([]byte, 0, len(s)) - flag, num := true, len(s)-1 - for i := 0; i <= num; i++ { - d := s[i] - if d == '_' { - flag = true - continue - } else if flag { - if d >= 'a' && d <= 'z' { - d = d - 32 - } - flag = false - } - data = append(data, d) - } - return string(data[:]) -} - -type argString []string - -// get string by index from string slice -func (a argString) Get(i int, args ...string) (r string) { +// Get get string by index from string slice +func (a ArgString) Get(i int, args ...string) (r string) { if i >= 0 && i < len(a) { r = a[i] } else if len(args) > 0 { @@ -289,10 +209,10 @@ func (a argString) Get(i int, args ...string) (r string) { return } -type argInt []int +type ArgInt []int -// get int by index from int slice -func (a argInt) Get(i int, args ...int) (r int) { +// Get get int by index from int slice +func (a ArgInt) Get(i int, args ...int) (r int) { if i >= 0 && i < len(a) { r = a[i] } @@ -302,18 +222,28 @@ func (a argInt) Get(i int, args ...int) (r int) { return } -// parse time to string with location -func timeParse(dateString, format string) (time.Time, error) { +// TimeParse parse time to string with location +func TimeParse(dateString, format string) (time.Time, error) { tp, err := time.ParseInLocation(format, dateString, DefaultTimeLoc) return tp, err } -// get pointer indirect type -func indirectType(v reflect.Type) reflect.Type { +// IndirectType get pointer indirect type +func IndirectType(v reflect.Type) reflect.Type { switch v.Kind() { case reflect.Ptr: - return indirectType(v.Elem()) + return IndirectType(v.Elem()) default: return v } } + +const ( + FormatTime = "15:04:05" + FormatDate = "2006-01-02" + FormatDateTime = "2006-01-02 15:04:05" +) + +var ( + DefaultTimeLoc = time.Local +) diff --git a/client/orm/invocation.go b/client/orm/invocation.go new file mode 100644 index 0000000000..21cbdc42eb --- /dev/null +++ b/client/orm/invocation.go @@ -0,0 +1,60 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "context" + "time" + + "github.com/beego/beego/v2/client/orm/internal/models" +) + +// Invocation represents an "Orm" invocation +type Invocation struct { + Method string + // Md may be nil in some cases. It depends on method + Md interface{} + // the args are All arguments except context.Context + Args []interface{} + + mi *models.ModelInfo + // f is the Orm operation + f func(ctx context.Context) []interface{} + + // insideTx indicates whether this is inside a transaction + InsideTx bool + TxStartTime time.Time + TxName string +} + +func (inv *Invocation) GetTableName() string { + if inv.mi != nil { + return inv.mi.Table + } + return "" +} + +func (inv *Invocation) execute(ctx context.Context) []interface{} { + return inv.f(ctx) +} + +// GetPkFieldName return the primary key of this table +// if not found, "" is returned +func (inv *Invocation) GetPkFieldName() string { + if inv.mi.Fields.Pk != nil { + return inv.mi.Fields.Pk.Name + } + return "" +} diff --git a/migration/ddl.go b/client/orm/migration/ddl.go similarity index 84% rename from migration/ddl.go rename to client/orm/migration/ddl.go index cd2c1c49d8..bcbfdd5a06 100644 --- a/migration/ddl.go +++ b/client/orm/migration/ddl.go @@ -17,7 +17,7 @@ package migration import ( "fmt" - "github.com/astaxie/beego/logs" + "github.com/beego/beego/v2/core/logs" ) // Index struct defines the structure of Index Columns @@ -31,7 +31,7 @@ type Unique struct { Columns []*Column } -//Column struct defines a single column of a table +// Column struct defines a single column of a table type Column struct { Name string Inc string @@ -84,7 +84,7 @@ func (m *Migration) NewCol(name string) *Column { return col } -//PriCol creates a new primary column and attaches it to m struct +// PriCol creates a new primary column and attaches it to m struct func (m *Migration) PriCol(name string) *Column { col := &Column{Name: name} m.AddColumns(col) @@ -92,7 +92,7 @@ func (m *Migration) PriCol(name string) *Column { return col } -//UniCol creates / appends columns to specified unique key and attaches it to m struct +// UniCol creates / appends columns to specified unique key and attaches it to m struct func (m *Migration) UniCol(uni, name string) *Column { col := &Column{Name: name} m.AddColumns(col) @@ -114,34 +114,33 @@ func (m *Migration) UniCol(uni, name string) *Column { return col } -//ForeignCol creates a new foreign column and returns the instance of column +// ForeignCol creates a new foreign column and returns the instance of column func (m *Migration) ForeignCol(colname, foreigncol, foreigntable string) (foreign *Foreign) { - foreign = &Foreign{ForeignColumn: foreigncol, ForeignTable: foreigntable} foreign.Name = colname m.AddForeign(foreign) return foreign } -//SetOnDelete sets the on delete of foreign +// SetOnDelete sets the on delete of foreign func (foreign *Foreign) SetOnDelete(del string) *Foreign { foreign.OnDelete = "ON DELETE" + del return foreign } -//SetOnUpdate sets the on update of foreign +// SetOnUpdate sets the on update of foreign func (foreign *Foreign) SetOnUpdate(update string) *Foreign { foreign.OnUpdate = "ON UPDATE" + update return foreign } -//Remove marks the columns to be removed. -//it allows reverse m to create the column. +// Remove marks the columns to be removed. +// it allows reverse m to create the column. func (c *Column) Remove() { c.remove = true } -//SetAuto enables auto_increment of column (can be used once) +// SetAuto enables auto_increment of column (can be used once) func (c *Column) SetAuto(inc bool) *Column { if inc { c.Inc = "auto_increment" @@ -149,24 +148,23 @@ func (c *Column) SetAuto(inc bool) *Column { return c } -//SetNullable sets the column to be null +// SetNullable sets the column to be null func (c *Column) SetNullable(null bool) *Column { if null { c.Null = "" - } else { c.Null = "NOT NULL" } return c } -//SetDefault sets the default value, prepend with "DEFAULT " +// SetDefault sets the default value, prepend with "DEFAULT " func (c *Column) SetDefault(def string) *Column { c.Default = "DEFAULT " + def return c } -//SetUnsigned sets the column to be unsigned int +// SetUnsigned sets the column to be unsigned int func (c *Column) SetUnsigned(unsign bool) *Column { if unsign { c.Unsign = "UNSIGNED" @@ -174,30 +172,29 @@ func (c *Column) SetUnsigned(unsign bool) *Column { return c } -//SetDataType sets the dataType of the column +// SetDataType sets the dataType of the column func (c *Column) SetDataType(dataType string) *Column { c.DataType = dataType return c } -//SetOldNullable allows reverting to previous nullable on reverse ms +// SetOldNullable allows reverting to previous nullable on reverse ms func (c *RenameColumn) SetOldNullable(null bool) *RenameColumn { if null { c.OldNull = "" - } else { c.OldNull = "NOT NULL" } return c } -//SetOldDefault allows reverting to previous default on reverse ms +// SetOldDefault allows reverting to previous default on reverse ms func (c *RenameColumn) SetOldDefault(def string) *RenameColumn { c.OldDefault = def return c } -//SetOldUnsigned allows reverting to previous unsgined on reverse ms +// SetOldUnsigned allows reverting to previous unsgined on reverse ms func (c *RenameColumn) SetOldUnsigned(unsign bool) *RenameColumn { if unsign { c.OldUnsign = "UNSIGNED" @@ -205,66 +202,64 @@ func (c *RenameColumn) SetOldUnsigned(unsign bool) *RenameColumn { return c } -//SetOldDataType allows reverting to previous datatype on reverse ms +// SetOldDataType allows reverting to previous datatype on reverse ms func (c *RenameColumn) SetOldDataType(dataType string) *RenameColumn { c.OldDataType = dataType return c } -//SetPrimary adds the columns to the primary key (can only be used any number of times in only one m) +// SetPrimary adds the columns to the primary key (can only be used any number of times in only one m) func (c *Column) SetPrimary(m *Migration) *Column { m.Primary = append(m.Primary, c) return c } -//AddColumnsToUnique adds the columns to Unique Struct +// AddColumnsToUnique adds the columns to Unique Struct func (unique *Unique) AddColumnsToUnique(columns ...*Column) *Unique { - unique.Columns = append(unique.Columns, columns...) return unique } -//AddColumns adds columns to m struct +// AddColumns adds columns to m struct func (m *Migration) AddColumns(columns ...*Column) *Migration { - m.Columns = append(m.Columns, columns...) return m } -//AddPrimary adds the column to primary in m struct +// AddPrimary adds the column to primary in m struct func (m *Migration) AddPrimary(primary *Column) *Migration { m.Primary = append(m.Primary, primary) return m } -//AddUnique adds the column to unique in m struct +// AddUnique adds the column to unique in m struct func (m *Migration) AddUnique(unique *Unique) *Migration { m.Uniques = append(m.Uniques, unique) return m } -//AddForeign adds the column to foreign in m struct +// AddForeign adds the column to foreign in m struct func (m *Migration) AddForeign(foreign *Foreign) *Migration { m.Foreigns = append(m.Foreigns, foreign) return m } -//AddIndex adds the column to index in m struct +// AddIndex adds the column to index in m struct func (m *Migration) AddIndex(index *Index) *Migration { m.Indexes = append(m.Indexes, index) return m } -//RenameColumn allows renaming of columns +// RenameColumn allows renaming of columns func (m *Migration) RenameColumn(from, to string) *RenameColumn { rename := &RenameColumn{OldName: from, NewName: to} m.Renames = append(m.Renames, rename) return rename } -//GetSQL returns the generated sql depending on ModifyType +// GetSQL returns the generated sql depending on ModifyType func (m *Migration) GetSQL() (sql string) { sql = "" switch m.ModifyType { @@ -279,7 +274,7 @@ func (m *Migration) GetSQL() (sql string) { } if len(m.Primary) > 0 { - sql += fmt.Sprintf(",\n PRIMARY KEY( ") + sql += ",\n PRIMARY KEY( " } for index, column := range m.Primary { sql += fmt.Sprintf(" `%s`", column.Name) @@ -289,7 +284,7 @@ func (m *Migration) GetSQL() (sql string) { } if len(m.Primary) > 0 { - sql += fmt.Sprintf(")") + sql += ")" } for _, unique := range m.Uniques { @@ -300,7 +295,7 @@ func (m *Migration) GetSQL() (sql string) { sql += "," } } - sql += fmt.Sprintf(")") + sql += ")" } for _, foreign := range m.Foreigns { sql += fmt.Sprintf(",\n `%s` %s %s %s %s %s", foreign.Name, foreign.DataType, foreign.Unsign, foreign.Null, foreign.Inc, foreign.Default) @@ -361,7 +356,7 @@ func (m *Migration) GetSQL() (sql string) { } if len(m.Primary) > 0 { - sql += fmt.Sprintf("\n DROP PRIMARY KEY,") + sql += "\n DROP PRIMARY KEY," } for index, unique := range m.Uniques { diff --git a/migration/doc.go b/client/orm/migration/doc.go similarity index 100% rename from migration/doc.go rename to client/orm/migration/doc.go diff --git a/migration/migration.go b/client/orm/migration/migration.go similarity index 90% rename from migration/migration.go rename to client/orm/migration/migration.go index 97e10c2e66..d4937f790c 100644 --- a/migration/migration.go +++ b/client/orm/migration/migration.go @@ -33,8 +33,8 @@ import ( "strings" "time" - "github.com/astaxie/beego/logs" - "github.com/astaxie/beego/orm" + "github.com/beego/beego/v2/client/orm" + "github.com/beego/beego/v2/core/logs" ) // const the data format for the bee generate migration datatype @@ -52,7 +52,7 @@ type Migrationer interface { GetCreated() int64 } -//Migration defines the migrations by either SQL or DDL +// Migration defines the migrations by either SQL or DDL type Migration struct { sqls []string Created string @@ -72,9 +72,7 @@ type Migration struct { RemoveForeigns []*Foreign } -var ( - migrationMap map[string]Migrationer -) +var migrationMap map[string]Migrationer func init() { migrationMap = make(map[string]Migrationer) @@ -82,7 +80,6 @@ func init() { // Up implement in the Inheritance struct for upgrade func (m *Migration) Up() { - switch m.ModifyType { case "reverse": m.ModifyType = "alter" @@ -94,7 +91,6 @@ func (m *Migration) Up() { // Down implement in the Inheritance struct for down func (m *Migration) Down() { - switch m.ModifyType { case "alter": m.ModifyType = "reverse" @@ -104,7 +100,7 @@ func (m *Migration) Down() { m.sqls = append(m.sqls, m.GetSQL()) } -//Migrate adds the SQL to the execution list +// Migrate adds the SQL to the execution list func (m *Migration) Migrate(migrationType string) { m.ModifyType = migrationType m.sqls = append(m.sqls, m.GetSQL()) @@ -140,7 +136,7 @@ func (m *Migration) addOrUpdateRecord(name, status string) error { status = "rollback" p, err := o.Raw("update migrations set status = ?, rollback_statements = ?, created_at = ? where name = ?").Prepare() if err != nil { - return nil + return err } _, err = p.Exec(status, strings.Join(m.sqls, "; "), time.Now().Format(DBDateFormat), name) return err @@ -176,8 +172,9 @@ func Register(name string, m Migrationer) error { func Upgrade(lasttime int64) error { sm := sortMap(migrationMap) i := 0 + migs, _ := getAllMigrations() for _, v := range sm { - if v.created > lasttime { + if _, ok := migs[v.name]; !ok { logs.Info("start upgrade", v.name) v.m.Reset() v.m.Up() @@ -310,3 +307,21 @@ func isRollBack(name string) bool { } return false } + +func getAllMigrations() (map[string]string, error) { + o := orm.NewOrm() + var maps []orm.Params + migs := make(map[string]string) + num, err := o.Raw("select * from migrations order by id_migration desc").Values(&maps) + if err != nil { + logs.Info("get name has error", err) + return migs, err + } + if num > 0 { + for _, v := range maps { + name := v["name"].(string) + migs[name] = v["status"].(string) + } + } + return migs, nil +} diff --git a/client/orm/mock/condition.go b/client/orm/mock/condition.go new file mode 100644 index 0000000000..eda888243e --- /dev/null +++ b/client/orm/mock/condition.go @@ -0,0 +1,63 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mock + +import ( + "context" + + "github.com/beego/beego/v2/client/orm" +) + +type Mock struct { + cond Condition + resp []interface{} + cb func(inv *orm.Invocation) +} + +func NewMock(cond Condition, resp []interface{}, cb func(inv *orm.Invocation)) *Mock { + return &Mock{ + cond: cond, + resp: resp, + cb: cb, + } +} + +type Condition interface { + Match(ctx context.Context, inv *orm.Invocation) bool +} + +type SimpleCondition struct { + tableName string + method string +} + +func NewSimpleCondition(tableName string, methodName string) Condition { + return &SimpleCondition{ + tableName: tableName, + method: methodName, + } +} + +func (s *SimpleCondition) Match(ctx context.Context, inv *orm.Invocation) bool { + res := true + if len(s.tableName) != 0 { + res = res && (s.tableName == inv.GetTableName()) + } + + if len(s.method) != 0 { + res = res && (s.method == inv.Method) + } + return res +} diff --git a/client/orm/mock/condition_test.go b/client/orm/mock/condition_test.go new file mode 100644 index 0000000000..7f646e70fa --- /dev/null +++ b/client/orm/mock/condition_test.go @@ -0,0 +1,41 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mock + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/beego/beego/v2/client/orm" +) + +func TestSimpleCondition_Match(t *testing.T) { + cond := NewSimpleCondition("", "") + res := cond.Match(context.Background(), &orm.Invocation{}) + assert.True(t, res) + cond = NewSimpleCondition("hello", "") + assert.False(t, cond.Match(context.Background(), &orm.Invocation{})) + + cond = NewSimpleCondition("", "A") + assert.False(t, cond.Match(context.Background(), &orm.Invocation{ + Method: "B", + })) + + assert.True(t, cond.Match(context.Background(), &orm.Invocation{ + Method: "A", + })) +} diff --git a/client/orm/mock/context.go b/client/orm/mock/context.go new file mode 100644 index 0000000000..ca251c5d2f --- /dev/null +++ b/client/orm/mock/context.go @@ -0,0 +1,40 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mock + +import ( + "context" + + "github.com/beego/beego/v2/core/logs" +) + +type mockCtxKeyType string + +const mockCtxKey = mockCtxKeyType("beego-orm-mock") + +func CtxWithMock(ctx context.Context, mock ...*Mock) context.Context { + return context.WithValue(ctx, mockCtxKey, mock) +} + +func mockFromCtx(ctx context.Context) []*Mock { + ms := ctx.Value(mockCtxKey) + if ms != nil { + if res, ok := ms.([]*Mock); ok { + return res + } + logs.Error("mockCtxKey found in context, but value is not type []*Mock") + } + return nil +} diff --git a/client/orm/mock/context_test.go b/client/orm/mock/context_test.go new file mode 100644 index 0000000000..a3ed1e905a --- /dev/null +++ b/client/orm/mock/context_test.go @@ -0,0 +1,29 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mock + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestCtx(t *testing.T) { + ms := make([]*Mock, 0, 4) + ctx := CtxWithMock(context.Background(), ms...) + res := mockFromCtx(ctx) + assert.Equal(t, ms, res) +} diff --git a/client/orm/mock/mock.go b/client/orm/mock/mock.go new file mode 100644 index 0000000000..7aca914c4a --- /dev/null +++ b/client/orm/mock/mock.go @@ -0,0 +1,71 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mock + +import ( + "context" + + "github.com/beego/beego/v2/client/orm" +) + +var stub = newOrmStub() + +func init() { + orm.AddGlobalFilterChain(stub.FilterChain) +} + +type Stub interface { + Mock(m *Mock) + Clear() +} + +type OrmStub struct { + ms []*Mock +} + +func StartMock() Stub { + return stub +} + +func newOrmStub() *OrmStub { + return &OrmStub{ + ms: make([]*Mock, 0, 4), + } +} + +func (o *OrmStub) Mock(m *Mock) { + o.ms = append(o.ms, m) +} + +func (o *OrmStub) Clear() { + o.ms = make([]*Mock, 0, 4) +} + +func (o *OrmStub) FilterChain(next orm.Filter) orm.Filter { + return func(ctx context.Context, inv *orm.Invocation) []interface{} { + ms := mockFromCtx(ctx) + ms = append(ms, o.ms...) + + for _, mock := range ms { + if mock.cond.Match(ctx, inv) { + if mock.cb != nil { + mock.cb(inv) + } + return mock.resp + } + } + return next(ctx, inv) + } +} diff --git a/client/orm/mock/mock_orm.go b/client/orm/mock/mock_orm.go new file mode 100644 index 0000000000..77fe24966d --- /dev/null +++ b/client/orm/mock/mock_orm.go @@ -0,0 +1,169 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mock + +import ( + "database/sql" + "os" + "path/filepath" + + _ "github.com/mattn/go-sqlite3" + + "github.com/beego/beego/v2/client/orm" +) + +func init() { + RegisterMockDB("default") +} + +// RegisterMockDB create an "virtual DB" by using sqllite +// you should not +func RegisterMockDB(name string) { + source := filepath.Join(os.TempDir(), name+".db") + _ = orm.RegisterDataBase(name, "sqlite3", source) +} + +// MockTable only check table name +func MockTable(tableName string, resp ...interface{}) *Mock { + return NewMock(NewSimpleCondition(tableName, ""), resp, nil) +} + +// MockMethod only check method name +func MockMethod(method string, resp ...interface{}) *Mock { + return NewMock(NewSimpleCondition("", method), resp, nil) +} + +// MockRead support orm.Read and orm.ReadWithCtx +// cb is used to mock read data from DB +func MockRead(tableName string, cb func(data interface{}), err error) *Mock { + return NewMock(NewSimpleCondition(tableName, "ReadWithCtx"), []interface{}{err}, func(inv *orm.Invocation) { + if cb != nil { + cb(inv.Args[0]) + } + }) +} + +// MockReadForUpdateWithCtx support ReadForUpdate and ReadForUpdateWithCtx +// cb is used to mock read data from DB +func MockReadForUpdateWithCtx(tableName string, cb func(data interface{}), err error) *Mock { + return NewMock(NewSimpleCondition(tableName, "ReadForUpdateWithCtx"), + []interface{}{err}, + func(inv *orm.Invocation) { + cb(inv.Args[0]) + }) +} + +// MockReadOrCreateWithCtx support ReadOrCreate and ReadOrCreateWithCtx +// cb is used to mock read data from DB +func MockReadOrCreateWithCtx(tableName string, + cb func(data interface{}), + insert bool, id int64, err error) *Mock { + return NewMock(NewSimpleCondition(tableName, "ReadOrCreateWithCtx"), + []interface{}{insert, id, err}, + func(inv *orm.Invocation) { + cb(inv.Args[0]) + }) +} + +// MockInsertWithCtx support Insert and InsertWithCtx +func MockInsertWithCtx(tableName string, id int64, err error) *Mock { + return NewMock(NewSimpleCondition(tableName, "InsertWithCtx"), []interface{}{id, err}, nil) +} + +// MockInsertMultiWithCtx support InsertMulti and InsertMultiWithCtx +func MockInsertMultiWithCtx(tableName string, cnt int64, err error) *Mock { + return NewMock(NewSimpleCondition(tableName, "InsertMultiWithCtx"), []interface{}{cnt, err}, nil) +} + +// MockInsertOrUpdateWithCtx support InsertOrUpdate and InsertOrUpdateWithCtx +func MockInsertOrUpdateWithCtx(tableName string, id int64, err error) *Mock { + return NewMock(NewSimpleCondition(tableName, "InsertOrUpdateWithCtx"), []interface{}{id, err}, nil) +} + +// MockUpdateWithCtx support UpdateWithCtx and Update +func MockUpdateWithCtx(tableName string, affectedRow int64, err error) *Mock { + return NewMock(NewSimpleCondition(tableName, "UpdateWithCtx"), []interface{}{affectedRow, err}, nil) +} + +// MockDeleteWithCtx support Delete and DeleteWithCtx +func MockDeleteWithCtx(tableName string, affectedRow int64, err error) *Mock { + return NewMock(NewSimpleCondition(tableName, "DeleteWithCtx"), []interface{}{affectedRow, err}, nil) +} + +// MockQueryM2MWithCtx support QueryM2MWithCtx and QueryM2M +// Now you may be need to use golang/mock to generate QueryM2M mock instance +// Or use DoNothingQueryM2Mer +// for example: +// +// post := Post{Id: 4} +// m2m := Ormer.QueryM2M(&post, "Tags") +// +// when you write test code: +// MockQueryM2MWithCtx("post", "Tags", mockM2Mer) +// "post" is the table name of model Post structure +// TODO provide orm.QueryM2Mer +func MockQueryM2MWithCtx(tableName string, name string, res orm.QueryM2Mer) *Mock { + return NewMock(NewQueryM2MerCondition(tableName, name), []interface{}{res}, nil) +} + +// MockLoadRelatedWithCtx support LoadRelatedWithCtx and LoadRelated +func MockLoadRelatedWithCtx(tableName string, name string, rows int64, err error) *Mock { + return NewMock(NewQueryM2MerCondition(tableName, name), []interface{}{rows, err}, nil) +} + +// MockQueryTableWithCtx support QueryTableWithCtx and QueryTable +func MockQueryTableWithCtx(tableName string, qs orm.QuerySeter) *Mock { + return NewMock(NewSimpleCondition(tableName, "QueryTable"), []interface{}{qs}, nil) +} + +// MockRawWithCtx support RawWithCtx and Raw +func MockRawWithCtx(rs orm.RawSeter) *Mock { + return NewMock(NewSimpleCondition("", "RawWithCtx"), []interface{}{rs}, nil) +} + +// MockDriver support Driver +// func MockDriver(driver orm.Driver) *Mock { +// return NewMock(NewSimpleCondition("", "Driver"), []interface{}{driver}) +// } + +// MockDBStats support DBStats +func MockDBStats(stats *sql.DBStats) *Mock { + return NewMock(NewSimpleCondition("", "DBStats"), []interface{}{stats}, nil) +} + +// MockBeginWithCtxAndOpts support Begin, BeginWithCtx, BeginWithOpts, BeginWithCtxAndOpts +// func MockBeginWithCtxAndOpts(txOrm *orm.TxOrmer, err error) *Mock { +// return NewMock(NewSimpleCondition("", "BeginWithCtxAndOpts"), []interface{}{txOrm, err}) +// } + +// MockDoTxWithCtxAndOpts support DoTx, DoTxWithCtx, DoTxWithOpts, DoTxWithCtxAndOpts +// func MockDoTxWithCtxAndOpts(txOrm *orm.TxOrmer, err error) *Mock { +// return MockBeginWithCtxAndOpts(txOrm, err) +// } + +// MockCommit support Commit +func MockCommit(err error) *Mock { + return NewMock(NewSimpleCondition("", "Commit"), []interface{}{err}, nil) +} + +// MockRollback support Rollback +func MockRollback(err error) *Mock { + return NewMock(NewSimpleCondition("", "Rollback"), []interface{}{err}, nil) +} + +// MockRollbackUnlessCommit support RollbackUnlessCommit +func MockRollbackUnlessCommit(err error) *Mock { + return NewMock(NewSimpleCondition("", "RollbackUnlessCommit"), []interface{}{err}, nil) +} diff --git a/client/orm/mock/mock_orm_test.go b/client/orm/mock/mock_orm_test.go new file mode 100644 index 0000000000..29d9ea0197 --- /dev/null +++ b/client/orm/mock/mock_orm_test.go @@ -0,0 +1,312 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mock + +import ( + "context" + "database/sql" + "errors" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/beego/beego/v2/client/orm" +) + +const mockErrorMsg = "mock error" + +func init() { + orm.RegisterModel(&User{}) +} + +func TestMockDBStats(t *testing.T) { + s := StartMock() + defer s.Clear() + stats := &sql.DBStats{} + s.Mock(MockDBStats(stats)) + + o := orm.NewOrm() + + res := o.DBStats() + + assert.Equal(t, stats, res) +} + +func TestMockDeleteWithCtx(t *testing.T) { + s := StartMock() + defer s.Clear() + s.Mock(MockDeleteWithCtx((&User{}).TableName(), 12, nil)) + o := orm.NewOrm() + rows, err := o.Delete(&User{}) + assert.Equal(t, int64(12), rows) + assert.Nil(t, err) +} + +func TestMockInsertOrUpdateWithCtx(t *testing.T) { + s := StartMock() + defer s.Clear() + s.Mock(MockInsertOrUpdateWithCtx((&User{}).TableName(), 12, nil)) + o := orm.NewOrm() + id, err := o.InsertOrUpdate(&User{}) + assert.Equal(t, int64(12), id) + assert.Nil(t, err) +} + +func TestMockRead(t *testing.T) { + s := StartMock() + defer s.Clear() + err := errors.New(mockErrorMsg) + s.Mock(MockRead((&User{}).TableName(), func(data interface{}) { + u := data.(*User) + u.Name = "Tom" + }, err)) + o := orm.NewOrm() + u := &User{} + e := o.Read(u) + assert.Equal(t, err, e) + assert.Equal(t, "Tom", u.Name) +} + +func TestMockQueryM2MWithCtx(t *testing.T) { + s := StartMock() + defer s.Clear() + mock := &DoNothingQueryM2Mer{} + s.Mock(MockQueryM2MWithCtx((&User{}).TableName(), "Tags", mock)) + o := orm.NewOrm() + res := o.QueryM2M(&User{}, "Tags") + assert.Equal(t, mock, res) +} + +func TestMockQueryTableWithCtx(t *testing.T) { + s := StartMock() + defer s.Clear() + mock := &DoNothingQuerySetter{} + s.Mock(MockQueryTableWithCtx((&User{}).TableName(), mock)) + o := orm.NewOrm() + res := o.QueryTable(&User{}) + assert.Equal(t, mock, res) +} + +func TestMockTable(t *testing.T) { + s := StartMock() + defer s.Clear() + mock := errors.New(mockErrorMsg) + s.Mock(MockTable((&User{}).TableName(), mock)) + o := orm.NewOrm() + res := o.Read(&User{}) + assert.Equal(t, mock, res) +} + +func TestMockInsertMultiWithCtx(t *testing.T) { + s := StartMock() + defer s.Clear() + mock := errors.New(mockErrorMsg) + s.Mock(MockInsertMultiWithCtx((&User{}).TableName(), 12, mock)) + o := orm.NewOrm() + res, err := o.InsertMulti(11, []interface{}{&User{}}) + assert.Equal(t, int64(12), res) + assert.Equal(t, mock, err) +} + +func TestMockInsertWithCtx(t *testing.T) { + s := StartMock() + defer s.Clear() + mock := errors.New(mockErrorMsg) + s.Mock(MockInsertWithCtx((&User{}).TableName(), 13, mock)) + o := orm.NewOrm() + res, err := o.Insert(&User{}) + assert.Equal(t, int64(13), res) + assert.Equal(t, mock, err) +} + +func TestMockUpdateWithCtx(t *testing.T) { + s := StartMock() + defer s.Clear() + mock := errors.New(mockErrorMsg) + s.Mock(MockUpdateWithCtx((&User{}).TableName(), 12, mock)) + o := orm.NewOrm() + res, err := o.Update(&User{}) + assert.Equal(t, int64(12), res) + assert.Equal(t, mock, err) +} + +func TestMockLoadRelatedWithCtx(t *testing.T) { + s := StartMock() + defer s.Clear() + mock := errors.New(mockErrorMsg) + s.Mock(MockLoadRelatedWithCtx((&User{}).TableName(), "T", 12, mock)) + o := orm.NewOrm() + res, err := o.LoadRelated(&User{}, "T") + assert.Equal(t, int64(12), res) + assert.Equal(t, mock, err) +} + +func TestMockMethod(t *testing.T) { + s := StartMock() + defer s.Clear() + mock := errors.New(mockErrorMsg) + s.Mock(MockMethod("ReadWithCtx", mock)) + o := orm.NewOrm() + err := o.Read(&User{}) + assert.Equal(t, mock, err) +} + +func TestMockReadForUpdateWithCtx(t *testing.T) { + s := StartMock() + defer s.Clear() + mock := errors.New(mockErrorMsg) + s.Mock(MockReadForUpdateWithCtx((&User{}).TableName(), func(data interface{}) { + u := data.(*User) + u.Name = "Tom" + }, mock)) + o := orm.NewOrm() + u := &User{} + err := o.ReadForUpdate(u) + assert.Equal(t, mock, err) + assert.Equal(t, "Tom", u.Name) +} + +func TestMockRawWithCtx(t *testing.T) { + s := StartMock() + defer s.Clear() + mock := &DoNothingRawSetter{} + s.Mock(MockRawWithCtx(mock)) + o := orm.NewOrm() + res := o.Raw("") + assert.Equal(t, mock, res) +} + +func TestMockReadOrCreateWithCtx(t *testing.T) { + s := StartMock() + defer s.Clear() + mock := errors.New(mockErrorMsg) + s.Mock(MockReadOrCreateWithCtx((&User{}).TableName(), func(data interface{}) { + u := data.(*User) + u.Name = "Tom" + }, false, 12, mock)) + o := orm.NewOrm() + u := &User{} + inserted, id, err := o.ReadOrCreate(u, "") + assert.Equal(t, mock, err) + assert.Equal(t, int64(12), id) + assert.False(t, inserted) + assert.Equal(t, "Tom", u.Name) +} + +func TestTransactionClosure(t *testing.T) { + s := StartMock() + defer s.Clear() + mock := errors.New(mockErrorMsg) + s.Mock(MockRead((&User{}).TableName(), func(data interface{}) { + u := data.(*User) + u.Name = "Tom" + }, mock)) + u, err := originalTxUsingClosure() + assert.Equal(t, mock, err) + assert.Equal(t, "Tom", u.Name) +} + +func TestTransactionManually(t *testing.T) { + s := StartMock() + defer s.Clear() + mock := errors.New(mockErrorMsg) + s.Mock(MockRead((&User{}).TableName(), func(data interface{}) { + u := data.(*User) + u.Name = "Tom" + }, mock)) + u, err := originalTxManually() + assert.Equal(t, mock, err) + assert.Equal(t, "Tom", u.Name) +} + +func TestTransactionRollback(t *testing.T) { + s := StartMock() + defer s.Clear() + mock := errors.New(mockErrorMsg) + s.Mock(MockRead((&User{}).TableName(), nil, errors.New("read error"))) + s.Mock(MockRollback(mock)) + _, err := originalTx() + assert.Equal(t, mock, err) +} + +func TestTransactionRollbackUnlessCommit(t *testing.T) { + s := StartMock() + defer s.Clear() + mock := errors.New(mockErrorMsg) + s.Mock(MockRollbackUnlessCommit(mock)) + + // u := &User{} + o := orm.NewOrm() + txOrm, _ := o.Begin() + err := txOrm.RollbackUnlessCommit() + assert.Equal(t, mock, err) +} + +func TestTransactionCommit(t *testing.T) { + s := StartMock() + defer s.Clear() + mock := errors.New(mockErrorMsg) + s.Mock(MockRead((&User{}).TableName(), func(data interface{}) { + u := data.(*User) + u.Name = "Tom" + }, nil)) + s.Mock(MockCommit(mock)) + u, err := originalTx() + assert.Equal(t, mock, err) + assert.Equal(t, "Tom", u.Name) +} + +func originalTx() (*User, error) { + u := &User{} + o := orm.NewOrm() + txOrm, _ := o.Begin() + err := txOrm.Read(u) + if err == nil { + err = txOrm.Commit() + return u, err + } else { + err = txOrm.Rollback() + return nil, err + } +} + +func originalTxManually() (*User, error) { + u := &User{} + o := orm.NewOrm() + txOrm, _ := o.Begin() + err := txOrm.Read(u) + _ = txOrm.Commit() + return u, err +} + +func originalTxUsingClosure() (*User, error) { + u := &User{} + var err error + o := orm.NewOrm() + _ = o.DoTx(func(ctx context.Context, txOrm orm.TxOrmer) error { + err = txOrm.Read(u) + return nil + }) + return u, err +} + +type User struct { + Id int + Name string +} + +func (u *User) TableName() string { + return "user" +} diff --git a/client/orm/mock/mock_queryM2Mer.go b/client/orm/mock/mock_queryM2Mer.go new file mode 100644 index 0000000000..732e27b68c --- /dev/null +++ b/client/orm/mock/mock_queryM2Mer.go @@ -0,0 +1,88 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mock + +import ( + "context" + + "github.com/beego/beego/v2/client/orm" +) + +// DoNothingQueryM2Mer do nothing +// use it to build mock orm.QueryM2Mer +type DoNothingQueryM2Mer struct{} + +func (d *DoNothingQueryM2Mer) AddWithCtx(ctx context.Context, i ...interface{}) (int64, error) { + return 0, nil +} + +func (d *DoNothingQueryM2Mer) RemoveWithCtx(ctx context.Context, i ...interface{}) (int64, error) { + return 0, nil +} + +func (d *DoNothingQueryM2Mer) ExistWithCtx(ctx context.Context, i interface{}) bool { + return true +} + +func (d *DoNothingQueryM2Mer) ClearWithCtx(ctx context.Context) (int64, error) { + return 0, nil +} + +func (d *DoNothingQueryM2Mer) CountWithCtx(ctx context.Context) (int64, error) { + return 0, nil +} + +func (d *DoNothingQueryM2Mer) Add(i ...interface{}) (int64, error) { + return 0, nil +} + +func (d *DoNothingQueryM2Mer) Remove(i ...interface{}) (int64, error) { + return 0, nil +} + +func (d *DoNothingQueryM2Mer) Exist(i interface{}) bool { + return true +} + +func (d *DoNothingQueryM2Mer) Clear() (int64, error) { + return 0, nil +} + +func (d *DoNothingQueryM2Mer) Count() (int64, error) { + return 0, nil +} + +type QueryM2MerCondition struct { + tableName string + name string +} + +func NewQueryM2MerCondition(tableName string, name string) *QueryM2MerCondition { + return &QueryM2MerCondition{ + tableName: tableName, + name: name, + } +} + +func (q *QueryM2MerCondition) Match(ctx context.Context, inv *orm.Invocation) bool { + res := true + if len(q.tableName) > 0 { + res = res && (q.tableName == inv.GetTableName()) + } + if len(q.name) > 0 { + res = res && (len(inv.Args) > 1) && (q.name == inv.Args[1].(string)) + } + return res +} diff --git a/client/orm/mock/mock_queryM2Mer_test.go b/client/orm/mock/mock_queryM2Mer_test.go new file mode 100644 index 0000000000..ef754092b7 --- /dev/null +++ b/client/orm/mock/mock_queryM2Mer_test.go @@ -0,0 +1,63 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mock + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/beego/beego/v2/client/orm" +) + +func TestDoNothingQueryM2Mer(t *testing.T) { + m2m := &DoNothingQueryM2Mer{} + + i, err := m2m.Clear() + assert.Equal(t, int64(0), i) + assert.Nil(t, err) + + i, err = m2m.Count() + assert.Equal(t, int64(0), i) + assert.Nil(t, err) + + i, err = m2m.Add() + assert.Equal(t, int64(0), i) + assert.Nil(t, err) + + i, err = m2m.Remove() + assert.Equal(t, int64(0), i) + assert.Nil(t, err) + + assert.True(t, m2m.Exist(nil)) +} + +func TestNewQueryM2MerCondition(t *testing.T) { + cond := NewQueryM2MerCondition("", "") + res := cond.Match(context.Background(), &orm.Invocation{}) + assert.True(t, res) + cond = NewQueryM2MerCondition("hello", "") + assert.False(t, cond.Match(context.Background(), &orm.Invocation{})) + + cond = NewQueryM2MerCondition("", "A") + assert.False(t, cond.Match(context.Background(), &orm.Invocation{ + Args: []interface{}{0, "B"}, + })) + + assert.True(t, cond.Match(context.Background(), &orm.Invocation{ + Args: []interface{}{0, "A"}, + })) +} diff --git a/client/orm/mock/mock_querySetter.go b/client/orm/mock/mock_querySetter.go new file mode 100644 index 0000000000..5bbf988819 --- /dev/null +++ b/client/orm/mock/mock_querySetter.go @@ -0,0 +1,182 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mock + +import ( + "context" + + "github.com/beego/beego/v2/client/orm" + "github.com/beego/beego/v2/client/orm/clauses/order_clause" +) + +// DoNothingQuerySetter do nothing +// usually you use this to build your mock QuerySetter +type DoNothingQuerySetter struct{} + +func (d *DoNothingQuerySetter) OrderClauses(orders ...*order_clause.Order) orm.QuerySeter { + return d +} + +func (d *DoNothingQuerySetter) CountWithCtx(ctx context.Context) (int64, error) { + return 0, nil +} + +func (d *DoNothingQuerySetter) ExistWithCtx(ctx context.Context) bool { + return true +} + +func (d *DoNothingQuerySetter) UpdateWithCtx(ctx context.Context, values orm.Params) (int64, error) { + return 0, nil +} + +func (d *DoNothingQuerySetter) DeleteWithCtx(ctx context.Context) (int64, error) { + return 0, nil +} + +func (d *DoNothingQuerySetter) PrepareInsertWithCtx(ctx context.Context) (orm.Inserter, error) { + return nil, nil +} + +func (d *DoNothingQuerySetter) AllWithCtx(ctx context.Context, container interface{}, cols ...string) (int64, error) { + return 0, nil +} + +func (d *DoNothingQuerySetter) OneWithCtx(ctx context.Context, container interface{}, cols ...string) error { + return nil +} + +func (d *DoNothingQuerySetter) ValuesWithCtx(ctx context.Context, results *[]orm.Params, exprs ...string) (int64, error) { + return 0, nil +} + +func (d *DoNothingQuerySetter) ValuesListWithCtx(ctx context.Context, results *[]orm.ParamsList, exprs ...string) (int64, error) { + return 0, nil +} + +func (d *DoNothingQuerySetter) ValuesFlatWithCtx(ctx context.Context, result *orm.ParamsList, expr string) (int64, error) { + return 0, nil +} + +func (d *DoNothingQuerySetter) Aggregate(s string) orm.QuerySeter { + return d +} + +func (d *DoNothingQuerySetter) Filter(s string, i ...interface{}) orm.QuerySeter { + return d +} + +func (d *DoNothingQuerySetter) FilterRaw(s string, s2 string) orm.QuerySeter { + return d +} + +func (d *DoNothingQuerySetter) Exclude(s string, i ...interface{}) orm.QuerySeter { + return d +} + +func (d *DoNothingQuerySetter) SetCond(condition *orm.Condition) orm.QuerySeter { + return d +} + +func (d *DoNothingQuerySetter) GetCond() *orm.Condition { + return orm.NewCondition() +} + +func (d *DoNothingQuerySetter) Limit(limit interface{}, args ...interface{}) orm.QuerySeter { + return d +} + +func (d *DoNothingQuerySetter) Offset(offset interface{}) orm.QuerySeter { + return d +} + +func (d *DoNothingQuerySetter) GroupBy(exprs ...string) orm.QuerySeter { + return d +} + +func (d *DoNothingQuerySetter) OrderBy(exprs ...string) orm.QuerySeter { + return d +} + +func (d *DoNothingQuerySetter) ForceIndex(indexes ...string) orm.QuerySeter { + return d +} + +func (d *DoNothingQuerySetter) UseIndex(indexes ...string) orm.QuerySeter { + return d +} + +func (d *DoNothingQuerySetter) IgnoreIndex(indexes ...string) orm.QuerySeter { + return d +} + +func (d *DoNothingQuerySetter) RelatedSel(params ...interface{}) orm.QuerySeter { + return d +} + +func (d *DoNothingQuerySetter) Distinct() orm.QuerySeter { + return d +} + +func (d *DoNothingQuerySetter) ForUpdate() orm.QuerySeter { + return d +} + +func (d *DoNothingQuerySetter) Count() (int64, error) { + return 0, nil +} + +func (d *DoNothingQuerySetter) Exist() bool { + return true +} + +func (d *DoNothingQuerySetter) Update(values orm.Params) (int64, error) { + return 0, nil +} + +func (d *DoNothingQuerySetter) Delete() (int64, error) { + return 0, nil +} + +func (d *DoNothingQuerySetter) PrepareInsert() (orm.Inserter, error) { + return nil, nil +} + +func (d *DoNothingQuerySetter) All(container interface{}, cols ...string) (int64, error) { + return 0, nil +} + +func (d *DoNothingQuerySetter) One(container interface{}, cols ...string) error { + return nil +} + +func (d *DoNothingQuerySetter) Values(results *[]orm.Params, exprs ...string) (int64, error) { + return 0, nil +} + +func (d *DoNothingQuerySetter) ValuesList(results *[]orm.ParamsList, exprs ...string) (int64, error) { + return 0, nil +} + +func (d *DoNothingQuerySetter) ValuesFlat(result *orm.ParamsList, expr string) (int64, error) { + return 0, nil +} + +func (d *DoNothingQuerySetter) RowsToMap(result *orm.Params, keyCol, valueCol string) (int64, error) { + return 0, nil +} + +func (d *DoNothingQuerySetter) RowsToStruct(ptrStruct interface{}, keyCol, valueCol string) (int64, error) { + return 0, nil +} diff --git a/client/orm/mock/mock_querySetter_test.go b/client/orm/mock/mock_querySetter_test.go new file mode 100644 index 0000000000..09e5ad8c54 --- /dev/null +++ b/client/orm/mock/mock_querySetter_test.go @@ -0,0 +1,74 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mock + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestDoNothingQuerySetter(t *testing.T) { + setter := &DoNothingQuerySetter{} + setter.GroupBy().Filter("").Limit(10). + Distinct().Exclude("a").FilterRaw("", ""). + ForceIndex().ForUpdate().IgnoreIndex(). + Offset(11).OrderBy().RelatedSel().SetCond(nil).UseIndex() + + assert.True(t, setter.Exist()) + err := setter.One(nil) + assert.Nil(t, err) + i, err := setter.Count() + assert.Equal(t, int64(0), i) + assert.Nil(t, err) + + i, err = setter.Delete() + assert.Equal(t, int64(0), i) + assert.Nil(t, err) + + i, err = setter.All(nil) + assert.Equal(t, int64(0), i) + assert.Nil(t, err) + + i, err = setter.Update(nil) + assert.Equal(t, int64(0), i) + assert.Nil(t, err) + + i, err = setter.RowsToMap(nil, "", "") + assert.Equal(t, int64(0), i) + assert.Nil(t, err) + + i, err = setter.RowsToStruct(nil, "", "") + assert.Equal(t, int64(0), i) + assert.Nil(t, err) + + i, err = setter.Values(nil) + assert.Equal(t, int64(0), i) + assert.Nil(t, err) + + i, err = setter.ValuesFlat(nil, "") + assert.Equal(t, int64(0), i) + assert.Nil(t, err) + + i, err = setter.ValuesList(nil) + assert.Equal(t, int64(0), i) + assert.Nil(t, err) + + ins, err := setter.PrepareInsert() + assert.Nil(t, err) + assert.Nil(t, ins) + + assert.NotNil(t, setter.GetCond()) +} diff --git a/client/orm/mock/mock_rawSetter.go b/client/orm/mock/mock_rawSetter.go new file mode 100644 index 0000000000..f4768cc856 --- /dev/null +++ b/client/orm/mock/mock_rawSetter.go @@ -0,0 +1,63 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mock + +import ( + "database/sql" + + "github.com/beego/beego/v2/client/orm" +) + +type DoNothingRawSetter struct{} + +func (d *DoNothingRawSetter) Exec() (sql.Result, error) { + return nil, nil +} + +func (d *DoNothingRawSetter) QueryRow(containers ...interface{}) error { + return nil +} + +func (d *DoNothingRawSetter) QueryRows(containers ...interface{}) (int64, error) { + return 0, nil +} + +func (d *DoNothingRawSetter) SetArgs(i ...interface{}) orm.RawSeter { + return d +} + +func (d *DoNothingRawSetter) Values(container *[]orm.Params, cols ...string) (int64, error) { + return 0, nil +} + +func (d *DoNothingRawSetter) ValuesList(container *[]orm.ParamsList, cols ...string) (int64, error) { + return 0, nil +} + +func (d *DoNothingRawSetter) ValuesFlat(container *orm.ParamsList, cols ...string) (int64, error) { + return 0, nil +} + +func (d *DoNothingRawSetter) RowsToMap(result *orm.Params, keyCol, valueCol string) (int64, error) { + return 0, nil +} + +func (d *DoNothingRawSetter) RowsToStruct(ptrStruct interface{}, keyCol, valueCol string) (int64, error) { + return 0, nil +} + +func (d *DoNothingRawSetter) Prepare() (orm.RawPreparer, error) { + return nil, nil +} diff --git a/client/orm/mock/mock_rawSetter_test.go b/client/orm/mock/mock_rawSetter_test.go new file mode 100644 index 0000000000..dd98edbd6e --- /dev/null +++ b/client/orm/mock/mock_rawSetter_test.go @@ -0,0 +1,63 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mock + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestDoNothingRawSetter(t *testing.T) { + rs := &DoNothingRawSetter{} + i, err := rs.ValuesList(nil) + assert.Equal(t, int64(0), i) + assert.Nil(t, err) + + i, err = rs.Values(nil) + assert.Equal(t, int64(0), i) + assert.Nil(t, err) + + i, err = rs.ValuesFlat(nil) + assert.Equal(t, int64(0), i) + assert.Nil(t, err) + + i, err = rs.RowsToStruct(nil, "", "") + assert.Equal(t, int64(0), i) + assert.Nil(t, err) + + i, err = rs.RowsToMap(nil, "", "") + assert.Equal(t, int64(0), i) + assert.Nil(t, err) + + i, err = rs.QueryRows() + assert.Equal(t, int64(0), i) + assert.Nil(t, err) + + err = rs.QueryRow() + // assert.Equal(t, int64(0), i) + assert.Nil(t, err) + + s, err := rs.Exec() + assert.Nil(t, err) + assert.Nil(t, s) + + p, err := rs.Prepare() + assert.Nil(t, err) + assert.Nil(t, p) + + rrs := rs.SetArgs() + assert.Equal(t, rrs, rs) +} diff --git a/client/orm/mock/mock_test.go b/client/orm/mock/mock_test.go new file mode 100644 index 0000000000..73bce4e50e --- /dev/null +++ b/client/orm/mock/mock_test.go @@ -0,0 +1,58 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mock + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/beego/beego/v2/client/orm" +) + +func TestOrmStub_FilterChain(t *testing.T) { + os := newOrmStub() + inv := &orm.Invocation{ + Args: []interface{}{10}, + } + i := 1 + os.FilterChain(func(ctx context.Context, inv *orm.Invocation) []interface{} { + i++ + return nil + })(context.Background(), inv) + + assert.Equal(t, 2, i) + + m := NewMock(NewSimpleCondition("", ""), nil, func(inv *orm.Invocation) { + arg := inv.Args[0] + j := arg.(int) + inv.Args[0] = j + 1 + }) + os.Mock(m) + + os.FilterChain(nil)(context.Background(), inv) + assert.Equal(t, 11, inv.Args[0]) + + inv.Args[0] = 10 + ctxMock := NewMock(NewSimpleCondition("", ""), nil, func(inv *orm.Invocation) { + arg := inv.Args[0] + j := arg.(int) + inv.Args[0] = j + 3 + }) + + os.FilterChain(nil)(CtxWithMock(context.Background(), ctxMock), inv) + assert.Equal(t, 13, inv.Args[0]) +} diff --git a/client/orm/model_utils_test.go b/client/orm/model_utils_test.go new file mode 100644 index 0000000000..3e8b2042c1 --- /dev/null +++ b/client/orm/model_utils_test.go @@ -0,0 +1,15 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm diff --git a/client/orm/models_boot.go b/client/orm/models_boot.go new file mode 100644 index 0000000000..0a0f90d9a1 --- /dev/null +++ b/client/orm/models_boot.go @@ -0,0 +1,63 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "fmt" + "runtime/debug" + + imodels "github.com/beego/beego/v2/client/orm/internal/models" +) + +// RegisterModel Register models +func RegisterModel(models ...interface{}) { + RegisterModelWithPrefix("", models...) +} + +// RegisterModelWithPrefix Register models with a prefix +func RegisterModelWithPrefix(prefix string, models ...interface{}) { + if err := imodels.DefaultModelCache.Register(prefix, true, models...); err != nil { + panic(err) + } +} + +// RegisterModelWithSuffix Register models with a suffix +func RegisterModelWithSuffix(suffix string, models ...interface{}) { + if err := imodels.DefaultModelCache.Register(suffix, false, models...); err != nil { + panic(err) + } +} + +// BootStrap Bootstrap models. +// make All model parsed and can not add more models +func BootStrap() { + BootStrapWithDB("default") +} + +// BootStrapWithDB Bootstrap models with alias name. +func BootStrapWithDB(alias string) { + if _, ok := dataBaseCache.get(alias); !ok { + fmt.Printf("must have one Register DataBase alias named %q\n", alias) + debug.PrintStack() + return + } + imodels.DefaultModelCache.Bootstrap() +} + +// ResetModelCache Clean model cache. Then you can re-RegisterModel. +// Common use this api for test case. +func ResetModelCache() { + imodels.DefaultModelCache.Clean() +} diff --git a/client/orm/models_fields.go b/client/orm/models_fields.go new file mode 100644 index 0000000000..1fda9f1e70 --- /dev/null +++ b/client/orm/models_fields.go @@ -0,0 +1,173 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "github.com/beego/beego/v2/client/orm/internal/models" +) + +// Define the Type enum +const ( + TypeBooleanField = models.TypeBooleanField + TypeVarCharField = models.TypeVarCharField + TypeCharField = models.TypeCharField + TypeTextField = models.TypeTextField + TypeTimeField = models.TypeTimeField + TypeDateField = models.TypeDateField + TypeDateTimeField = models.TypeDateTimeField + TypeBitField = models.TypeBitField + TypeSmallIntegerField = models.TypeSmallIntegerField + TypeIntegerField = models.TypeIntegerField + TypeBigIntegerField = models.TypeBigIntegerField + TypePositiveBitField = models.TypePositiveBitField + TypePositiveSmallIntegerField = models.TypePositiveSmallIntegerField + TypePositiveIntegerField = models.TypePositiveIntegerField + TypePositiveBigIntegerField = models.TypePositiveBigIntegerField + TypeFloatField = models.TypeFloatField + TypeDecimalField = models.TypeDecimalField + TypeJSONField = models.TypeJSONField + TypeJsonbField = models.TypeJsonbField + RelForeignKey = models.RelForeignKey + RelOneToOne = models.RelOneToOne + RelManyToMany = models.RelManyToMany + RelReverseOne = models.RelReverseOne + RelReverseMany = models.RelReverseMany +) + +// Define some logic enum +const ( + IsIntegerField = models.IsIntegerField + IsPositiveIntegerField = models.IsPositiveIntegerField + IsRelField = models.IsRelField + IsFieldType = models.IsFieldType +) + +// BooleanField A true/false field. +type BooleanField = models.BooleanField + +// verify the BooleanField implement the Fielder interface +var _ Fielder = new(BooleanField) + +// CharField A string field +// required values tag: size +// The size is enforced at the database level and in models’s validation. +// eg: `orm:"size(120)"` +type CharField = models.CharField + +// verify CharField implement Fielder +var _ Fielder = new(CharField) + +// TimeField A time, represented in go by a time.Time instance. +// only time values like 10:00:00 +// Has a few extra, optional attr tag: +// +// auto_now: +// Automatically Set the field to now every time the object is saved. Useful for “last-modified” timestamps. +// Note that the current date is always used; it’s not just a default value that you can override. +// +// auto_now_add: +// Automatically Set the field to now when the object is first created. Useful for creation of timestamps. +// Note that the current date is always used; it’s not just a default value that you can override. +// +// eg: `orm:"auto_now"` or `orm:"auto_now_add"` +type TimeField = models.TimeField + +var _ Fielder = new(TimeField) + +// DateField A date, represented in go by a time.Time instance. +// only date values like 2006-01-02 +// Has a few extra, optional attr tag: +// +// auto_now: +// Automatically Set the field to now every time the object is saved. Useful for “last-modified” timestamps. +// Note that the current date is always used; it’s not just a default value that you can override. +// +// auto_now_add: +// Automatically Set the field to now when the object is first created. Useful for creation of timestamps. +// Note that the current date is always used; it’s not just a default value that you can override. +// +// eg: `orm:"auto_now"` or `orm:"auto_now_add"` +type DateField = models.DateField + +// verify DateField implement fielder interface +var _ Fielder = new(DateField) + +// DateTimeField A date, represented in go by a time.Time instance. +// datetime values like 2006-01-02 15:04:05 +// Takes the same extra arguments as DateField. +type DateTimeField = models.DateTimeField + +// verify datetime implement fielder +var _ models.Fielder = new(DateTimeField) + +// FloatField A floating-point number represented in go by a float32 value. +type FloatField = models.FloatField + +// verify FloatField implement Fielder +var _ Fielder = new(FloatField) + +// SmallIntegerField -32768 to 32767 +type SmallIntegerField = models.SmallIntegerField + +// verify SmallIntegerField implement Fielder +var _ Fielder = new(SmallIntegerField) + +// IntegerField -2147483648 to 2147483647 +type IntegerField = models.IntegerField + +// verify IntegerField implement Fielder +var _ Fielder = new(IntegerField) + +// BigIntegerField -9223372036854775808 to 9223372036854775807. +type BigIntegerField = models.BigIntegerField + +// verify BigIntegerField implement Fielder +var _ Fielder = new(BigIntegerField) + +// PositiveSmallIntegerField 0 to 65535 +type PositiveSmallIntegerField = models.PositiveSmallIntegerField + +// verify PositiveSmallIntegerField implement Fielder +var _ Fielder = new(PositiveSmallIntegerField) + +// PositiveIntegerField 0 to 4294967295 +type PositiveIntegerField = models.PositiveIntegerField + +// verify PositiveIntegerField implement Fielder +var _ Fielder = new(PositiveIntegerField) + +// PositiveBigIntegerField 0 to 18446744073709551615 +type PositiveBigIntegerField = models.PositiveBigIntegerField + +// verify PositiveBigIntegerField implement Fielder +var _ Fielder = new(PositiveBigIntegerField) + +// TextField A large text field. +type TextField = models.TextField + +// verify TextField implement Fielder +var _ Fielder = new(TextField) + +// JSONField postgres json field. +type JSONField = models.JSONField + +// verify JSONField implement Fielder +var _ models.Fielder = new(JSONField) + +// JsonbField postgres json field. +type JsonbField = models.JsonbField + +// verify JsonbField implement Fielder +var _ models.Fielder = new(JsonbField) diff --git a/orm/models_test.go b/client/orm/models_test.go similarity index 53% rename from orm/models_test.go rename to client/orm/models_test.go index e3a635f2d2..2684e0273d 100644 --- a/orm/models_test.go +++ b/client/orm/models_test.go @@ -22,11 +22,11 @@ import ( "strings" "time" + "github.com/beego/beego/v2/client/orm/internal/models" + _ "github.com/go-sql-driver/mysql" _ "github.com/lib/pq" _ "github.com/mattn/go-sqlite3" - // As tidb can't use go get, so disable the tidb testing now - // _ "github.com/pingcap/tidb" ) // A slice string field. @@ -53,18 +53,24 @@ func (e *SliceStringField) FieldType() int { } func (e *SliceStringField) SetRaw(value interface{}) error { - switch d := value.(type) { - case []string: - e.Set(d) - case string: - if len(d) > 0 { - parts := strings.Split(d, ",") + f := func(str string) { + if len(str) > 0 { + parts := strings.Split(str, ",") v := make([]string, 0, len(parts)) for _, p := range parts { v = append(v, strings.TrimSpace(p)) } e.Set(v) } + } + + switch d := value.(type) { + case []string: + e.Set(d) + case string: + f(d) + case []byte: + f(string(d)) default: return fmt.Errorf(" unknown value `%v`", value) } @@ -75,7 +81,7 @@ func (e *SliceStringField) RawValue() interface{} { return e.String() } -var _ Fielder = new(SliceStringField) +var _ models.Fielder = new(SliceStringField) // A json field. type JSONFieldTest struct { @@ -96,6 +102,8 @@ func (e *JSONFieldTest) SetRaw(value interface{}) error { switch d := value.(type) { case string: return json.Unmarshal([]byte(d), e) + case []byte: + return json.Unmarshal(d, e) default: return fmt.Errorf(" unknown value `%v`", value) } @@ -105,11 +113,15 @@ func (e *JSONFieldTest) RawValue() interface{} { return e.String() } -var _ Fielder = new(JSONFieldTest) +var _ models.Fielder = new(JSONFieldTest) type Data struct { ID int `orm:"column(id)"` Boolean bool + Byte byte + Int8 int8 + Uint8 uint8 + Rune rune Char string `orm:"size(50)"` Text string `orm:"type(text)"` JSON string `orm:"type(json);default({\"name\":\"json\"})"` @@ -117,110 +129,109 @@ type Data struct { Time time.Time `orm:"type(time)"` Date time.Time `orm:"type(date)"` DateTime time.Time `orm:"column(datetime)"` - Byte byte - Rune rune Int int - Int8 int8 + Uint uint Int16 int16 + Uint16 uint16 Int32 int32 Int64 int64 - Uint uint - Uint8 uint8 - Uint16 uint16 Uint32 uint32 - Uint64 uint64 Float32 float32 + Uint64 uint64 Float64 float64 Decimal float64 `orm:"digits(8);decimals(4)"` } type DataNull struct { - ID int `orm:"column(id)"` - Boolean bool `orm:"null"` - Char string `orm:"null;size(50)"` - Text string `orm:"null;type(text)"` - JSON string `orm:"type(json);null"` - Jsonb string `orm:"type(jsonb);null"` - Time time.Time `orm:"null;type(time)"` - Date time.Time `orm:"null;type(date)"` - DateTime time.Time `orm:"null;column(datetime)"` - Byte byte `orm:"null"` - Rune rune `orm:"null"` - Int int `orm:"null"` - Int8 int8 `orm:"null"` - Int16 int16 `orm:"null"` - Int32 int32 `orm:"null"` - Int64 int64 `orm:"null"` - Uint uint `orm:"null"` - Uint8 uint8 `orm:"null"` - Uint16 uint16 `orm:"null"` - Uint32 uint32 `orm:"null"` - Uint64 uint64 `orm:"null"` - Float32 float32 `orm:"null"` - Float64 float64 `orm:"null"` - Decimal float64 `orm:"digits(8);decimals(4);null"` - NullString sql.NullString `orm:"null"` - NullBool sql.NullBool `orm:"null"` - NullFloat64 sql.NullFloat64 `orm:"null"` - NullInt64 sql.NullInt64 `orm:"null"` - BooleanPtr *bool `orm:"null"` - CharPtr *string `orm:"null;size(50)"` - TextPtr *string `orm:"null;type(text)"` - BytePtr *byte `orm:"null"` - RunePtr *rune `orm:"null"` - IntPtr *int `orm:"null"` - Int8Ptr *int8 `orm:"null"` - Int16Ptr *int16 `orm:"null"` - Int32Ptr *int32 `orm:"null"` - Int64Ptr *int64 `orm:"null"` - UintPtr *uint `orm:"null"` - Uint8Ptr *uint8 `orm:"null"` - Uint16Ptr *uint16 `orm:"null"` - Uint32Ptr *uint32 `orm:"null"` - Uint64Ptr *uint64 `orm:"null"` - Float32Ptr *float32 `orm:"null"` - Float64Ptr *float64 `orm:"null"` - DecimalPtr *float64 `orm:"digits(8);decimals(4);null"` - TimePtr *time.Time `orm:"null;type(time)"` - DatePtr *time.Time `orm:"null;type(date)"` - DateTimePtr *time.Time `orm:"null"` -} - -type String string -type Boolean bool -type Byte byte -type Rune rune -type Int int -type Int8 int8 -type Int16 int16 -type Int32 int32 -type Int64 int64 -type Uint uint -type Uint8 uint8 -type Uint16 uint16 -type Uint32 uint32 -type Uint64 uint64 -type Float32 float64 -type Float64 float64 + ID int `orm:"column(id)"` + Char string `orm:"null;size(50)"` + Text string `orm:"null;type(text)"` + JSON string `orm:"type(json);null"` + Jsonb string `orm:"type(jsonb);null"` + Time time.Time `orm:"null;type(time)"` + Date time.Time `orm:"null;type(date)"` + DateTime time.Time `orm:"null;column(datetime)"` + DateTimePrecision time.Time `orm:"null;type(datetime);precision(4)"` + Boolean bool `orm:"null"` + Byte byte `orm:"null"` + Int8 int8 `orm:"null"` + Uint8 uint8 `orm:"null"` + Rune rune `orm:"null"` + Int int `orm:"null"` + Uint uint `orm:"null"` + Int16 int16 `orm:"null"` + Uint16 uint16 `orm:"null"` + Int32 int32 `orm:"null"` + Int64 int64 `orm:"null"` + Uint32 uint32 `orm:"null"` + Float32 float32 `orm:"null"` + Uint64 uint64 `orm:"null"` + Float64 float64 `orm:"null"` + Decimal float64 `orm:"digits(8);decimals(4);null"` + NullString sql.NullString `orm:"null"` + NullBool sql.NullBool `orm:"null"` + NullFloat64 sql.NullFloat64 `orm:"null"` + NullInt64 sql.NullInt64 `orm:"null"` + BooleanPtr *bool `orm:"null"` + CharPtr *string `orm:"null;size(50)"` + TextPtr *string `orm:"null;type(text)"` + BytePtr *byte `orm:"null"` + RunePtr *rune `orm:"null"` + IntPtr *int `orm:"null"` + Int8Ptr *int8 `orm:"null"` + Int16Ptr *int16 `orm:"null"` + Int32Ptr *int32 `orm:"null"` + Int64Ptr *int64 `orm:"null"` + UintPtr *uint `orm:"null"` + Uint8Ptr *uint8 `orm:"null"` + Uint16Ptr *uint16 `orm:"null"` + Uint32Ptr *uint32 `orm:"null"` + Uint64Ptr *uint64 `orm:"null"` + Float32Ptr *float32 `orm:"null"` + Float64Ptr *float64 `orm:"null"` + DecimalPtr *float64 `orm:"digits(8);decimals(4);null"` + TimePtr *time.Time `orm:"null;type(time)"` + DatePtr *time.Time `orm:"null;type(date)"` + DateTimePtr *time.Time `orm:"null"` +} + +type ( + String string + Boolean bool + Byte byte + Rune rune + Int int + Int8 int8 + Int16 int16 + Int32 int32 + Int64 int64 + Uint uint + Uint8 uint8 + Uint16 uint16 + Uint32 uint32 + Uint64 uint64 + Float32 float64 + Float64 float64 +) type DataCustom struct { ID int `orm:"column(id)"` Boolean Boolean - Char string `orm:"size(50)"` - Text string `orm:"type(text)"` Byte Byte + Int8 Int8 + Uint8 Uint8 Rune Rune + Char string `orm:"size(50)"` + Text string `orm:"type(text)"` Int Int - Int8 Int8 + Uint Uint Int16 Int16 + Uint16 Uint16 Int32 Int32 Int64 Int64 - Uint Uint - Uint8 Uint8 - Uint16 Uint16 Uint32 Uint32 - Uint64 Uint64 Float32 Float32 + Uint64 Uint64 Float64 Float64 Decimal Float64 `orm:"digits(8);decimals(4)"` } @@ -231,24 +242,55 @@ type UserBig struct { Name string } +type TM struct { + ID int `orm:"column(id)"` + TMPrecision1 time.Time `orm:"type(datetime);precision(3)"` + TMPrecision2 time.Time `orm:"auto_now_add;type(datetime);precision(4)"` +} + +func (t *TM) TableName() string { + return "tm" +} + +func NewTM() *TM { + obj := new(TM) + return obj +} + +type DeptInfo struct { + ID int `orm:"column(id)"` + Created time.Time `orm:"auto_now_add"` + DeptName string + EmployeeName string + Salary int +} + +type UnregisterModel struct { + ID int `orm:"column(id)"` + Created time.Time `orm:"auto_now_add"` + DeptName string + EmployeeName string + Salary int +} + type User struct { - ID int `orm:"column(id)"` - UserName string `orm:"size(30);unique"` - Email string `orm:"size(100)"` - Password string `orm:"size(100)"` - Status int16 `orm:"column(Status)"` - IsStaff bool - IsActive bool `orm:"default(true)"` - Created time.Time `orm:"auto_now_add;type(date)"` - Updated time.Time `orm:"auto_now"` - Profile *Profile `orm:"null;rel(one);on_delete(set_null)"` - Posts []*Post `orm:"reverse(many)" json:"-"` - ShouldSkip string `orm:"-"` - Nums int - Langs SliceStringField `orm:"size(100)"` - Extra JSONFieldTest `orm:"type(text)"` - unexport bool `orm:"-"` - unexportBool bool + ID int `orm:"column(id)"` + UserName string `orm:"size(30);unique"` + Email string `orm:"size(100)"` + Password string `orm:"size(100)"` + Status int16 `orm:"column(Status)"` + IsStaff bool + IsActive bool `orm:"default(true)"` + Unexported bool `orm:"-"` + UnexportedBool bool + Created time.Time `orm:"auto_now_add;type(date)"` + Updated time.Time `orm:"auto_now"` + Profile *Profile `orm:"null;rel(one);on_delete(set_null)"` + Posts []*Post `orm:"reverse(many)" json:"-"` + ShouldSkip string `orm:"-"` + Nums int + Langs SliceStringField `orm:"size(100)"` + Extra JSONFieldTest `orm:"type(text)"` } func (u *User) TableIndex() [][]string { @@ -287,13 +329,14 @@ func NewProfile() *Profile { } type Post struct { - ID int `orm:"column(id)"` - User *User `orm:"rel(fk)"` - Title string `orm:"size(60)"` - Content string `orm:"type(text)"` - Created time.Time `orm:"auto_now_add"` - Updated time.Time `orm:"auto_now"` - Tags []*Tag `orm:"rel(m2m);rel_through(github.com/astaxie/beego/orm.PostTags)"` + ID int `orm:"column(id)"` + User *User `orm:"rel(fk)"` + Title string `orm:"size(60)"` + Content string `orm:"type(text)"` + Created time.Time `orm:"auto_now_add"` + Updated time.Time `orm:"auto_now"` + UpdatedPrecision time.Time `orm:"auto_now;type(datetime);precision(4)"` + Tags []*Tag `orm:"rel(m2m);rel_through(github.com/beego/beego/v2/client/orm.PostTags)"` } func (u *Post) TableIndex() [][]string { @@ -307,6 +350,11 @@ func NewPost() *Post { return obj } +type NullValue struct { + ID int `orm:"column(id)"` + Value string `orm:"size(30);null"` +} + type Tag struct { ID int `orm:"column(id)"` Name string `orm:"size(30)"` @@ -351,7 +399,7 @@ type Group struct { type Permission struct { ID int `orm:"column(id)"` Name string - Groups []*Group `orm:"rel(m2m);rel_through(github.com/astaxie/beego/orm.GroupPermissions)"` + Groups []*Group `orm:"rel(m2m);rel_through(github.com/beego/beego/v2/client/orm.GroupPermissions)"` } type GroupPermissions struct { @@ -380,6 +428,15 @@ type InLine struct { Email string } +type Index struct { + // Common Fields + Id int `orm:"column(id)"` + + // Other Fields + F1 int `orm:"column(f1);index"` + F2 int `orm:"column(f2);index"` +} + func NewInLine() *InLine { return new(InLine) } @@ -411,6 +468,11 @@ type PtrPk struct { Positive bool } +type StrPk struct { + Id string `orm:"column(id);size(64);pk"` + Value string +} + var DBARGS = struct { Driver string Source string @@ -433,65 +495,66 @@ var ( dDbBaser dbBaser ) -var ( - helpinfo = `need driver and source! +var helpinfo = `need driver and source! Default DB Drivers. - + driver: url mysql: https://github.com/go-sql-driver/mysql sqlite3: https://github.com/mattn/go-sqlite3 postgres: https://github.com/lib/pq tidb: https://github.com/pingcap/tidb - + usage: - - go get -u github.com/astaxie/beego/orm - go get -u github.com/go-sql-driver/mysql - go get -u github.com/mattn/go-sqlite3 - go get -u github.com/lib/pq - go get -u github.com/pingcap/tidb - + + go Get -u github.com/beego/beego/v2/client/orm + go Get -u github.com/go-sql-driver/mysql + go Get -u github.com/mattn/go-sqlite3 + go Get -u github.com/lib/pq + go Get -u github.com/pingcap/tidb + #### MySQL mysql -u root -e 'create database orm_test;' export ORM_DRIVER=mysql export ORM_SOURCE="root:@/orm_test?charset=utf8" - go test -v github.com/astaxie/beego/orm - - + go test -v github.com/beego/beego/v2/client/orm + + #### Sqlite3 export ORM_DRIVER=sqlite3 export ORM_SOURCE='file:memory_test?mode=memory' - go test -v github.com/astaxie/beego/orm - - + go test -v github.com/beego/beego/v2/client/orm + + #### PostgreSQL psql -c 'create database orm_test;' -U postgres export ORM_DRIVER=postgres export ORM_SOURCE="user=postgres dbname=orm_test sslmode=disable" - go test -v github.com/astaxie/beego/orm - + go test -v github.com/beego/beego/v2/client/orm + #### TiDB export ORM_DRIVER=tidb export ORM_SOURCE='memory://test/test' - go test -v github.com/astaxie/beego/orm - + go test -v github.com/beego/beego/v2/pgk/orm + ` -) func init() { - Debug, _ = StrTo(DBARGS.Debug).Bool() + // Debug, _ = StrTo(DBARGS.Debug).Bool() + Debug = true if DBARGS.Driver == "" || DBARGS.Source == "" { fmt.Println(helpinfo) os.Exit(2) } - RegisterDataBase("default", DBARGS.Driver, DBARGS.Source, 20) - - alias := getDbAlias("default") - if alias.Driver == DRMySQL { - alias.Engine = "INNODB" + err := RegisterDataBase("default", DBARGS.Driver, DBARGS.Source, MaxIdleConnections(20)) + if err != nil { + panic(fmt.Sprintf("can not Register database: %v", err)) } + alias := getDB("default") + if alias.driver == DRMySQL { + alias.engine = "INNODB" + } } diff --git a/client/orm/orm.go b/client/orm/orm.go new file mode 100644 index 0000000000..1d5e399cd7 --- /dev/null +++ b/client/orm/orm.go @@ -0,0 +1,669 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package orm provide ORM for MySQL/PostgreSQL/sqlite +// Simple Usage +// +// package main +// +// import ( +// "fmt" +// "github.com/beego/beego/v2/client/orm" +// _ "github.com/go-sql-driver/mysql" // import your used driver +// ) +// +// // Model Struct +// type User struct { +// Id int `orm:"auto"` +// Name string `orm:"size(100)"` +// } +// +// func init() { +// orm.RegisterDataBase("default", "mysql", "root:root@/my_db?charset=utf8", 30) +// } +// +// func main() { +// o := orm.NewOrm() +// user := User{Name: "slene"} +// // insert +// id, err := o.Insert(&user) +// // update +// user.name = "astaxie" +// num, err := o.Update(&user) +// // read one +// u := User{Id: user.Id} +// err = o.Read(&u) +// // delete +// num, err = o.Delete(&u) +// } +package orm + +import ( + "context" + "database/sql" + "errors" + "fmt" + "os" + "reflect" + + iutils "github.com/beego/beego/v2/client/orm/internal/utils" + + "github.com/beego/beego/v2/client/orm/internal/models" + + "github.com/beego/beego/v2/client/orm/clauses/order_clause" + "github.com/beego/beego/v2/client/orm/hints" + "github.com/beego/beego/v2/core/logs" + "github.com/beego/beego/v2/core/utils" +) + +// DebugQueries define the debug +const ( + DebugQueries = iota +) + +// Define common vars +var ( + Debug = false + DebugLog = NewLog(os.Stdout) + DefaultRowsLimit = -1 + DefaultRelsDepth = 2 + DefaultTimeLoc = iutils.DefaultTimeLoc + ErrTxDone = errors.New(" transaction already done") + ErrMultiRows = errors.New(" return multi rows") + ErrNoRows = errors.New(" no row found") + ErrStmtClosed = errors.New(" stmt already closed") + ErrArgs = errors.New(" args error may be empty") + ErrNotImplement = errors.New("have not implement") + + ErrLastInsertIdUnavailable = errors.New(" last insert id is unavailable") +) + +// Params stores the Params +type Params map[string]interface{} + +// ParamsList stores paramslist +type ParamsList []interface{} + +type ormBase struct { + alias *DB + db dbQuerier +} + +var ( + _ DQL = new(ormBase) + _ DML = new(ormBase) + _ DriverGetter = new(ormBase) +) + +// Get model info and model reflect value +func (*ormBase) getMi(md interface{}) (mi *models.ModelInfo) { + val := reflect.ValueOf(md) + ind := reflect.Indirect(val) + typ := ind.Type() + mi = getTypeMi(typ) + return +} + +// Get need ptr model info and model reflect value +func (*ormBase) getPtrMiInd(md interface{}) (mi *models.ModelInfo, ind reflect.Value) { + val := reflect.ValueOf(md) + ind = reflect.Indirect(val) + typ := ind.Type() + if val.Kind() != reflect.Ptr { + panic(fmt.Errorf(" cannot use non-ptr model struct `%s`", models.GetFullName(typ))) + } + mi = getTypeMi(typ) + return +} + +func getTypeMi(mdTyp reflect.Type) *models.ModelInfo { + name := models.GetFullName(mdTyp) + if mi, ok := models.DefaultModelCache.GetByFullName(name); ok { + return mi + } + panic(fmt.Errorf(" table: `%s` not found, make sure it was registered with `RegisterModel()`", name)) +} + +// Get field info from model info by given field name +func (*ormBase) getFieldInfo(mi *models.ModelInfo, name string) *models.FieldInfo { + fi, ok := mi.Fields.GetByAny(name) + if !ok { + panic(fmt.Errorf(" cannot find field `%s` for model `%s`", name, mi.FullName)) + } + return fi +} + +// read data to model +func (o *ormBase) Read(md interface{}, cols ...string) error { + return o.ReadWithCtx(context.Background(), md, cols...) +} + +func (o *ormBase) ReadRaw(ctx context.Context, md interface{}, query string, args ...any) error { + mi, ind := o.getPtrMiInd(md) + return o.alias.dbBaser.ReadRaw(ctx, o.db, mi, ind, o.alias.tz, query, args...) +} +func (o *ormBase) ReadWithCtx(ctx context.Context, md interface{}, cols ...string) error { + mi, ind := o.getPtrMiInd(md) + return o.alias.dbBaser.Read(ctx, o.db, mi, ind, o.alias.tz, cols, false) +} + +// read data to model, like Read(), but use "SELECT FOR UPDATE" form +func (o *ormBase) ReadForUpdate(md interface{}, cols ...string) error { + return o.ReadForUpdateWithCtx(context.Background(), md, cols...) +} + +func (o *ormBase) ReadForUpdateWithCtx(ctx context.Context, md interface{}, cols ...string) error { + mi, ind := o.getPtrMiInd(md) + return o.alias.dbBaser.Read(ctx, o.db, mi, ind, o.alias.tz, cols, true) +} + +// Try to read a row from the database, or insert one if it doesn't exist +func (o *ormBase) ReadOrCreate(md interface{}, col1 string, cols ...string) (bool, int64, error) { + return o.ReadOrCreateWithCtx(context.Background(), md, col1, cols...) +} + +func (o *ormBase) ReadOrCreateWithCtx(ctx context.Context, md interface{}, col1 string, cols ...string) (bool, int64, error) { + cols = append([]string{col1}, cols...) + mi, ind := o.getPtrMiInd(md) + err := o.alias.dbBaser.Read(ctx, o.db, mi, ind, o.alias.tz, cols, false) + if err == ErrNoRows { + // Create + id, err := o.InsertWithCtx(ctx, md) + return err == nil, id, err + } + + id, vid := int64(0), ind.FieldByIndex(mi.Fields.Pk.FieldIndex) + if mi.Fields.Pk.FieldType&IsPositiveIntegerField > 0 { + id = int64(vid.Uint()) + } else if mi.Fields.Pk.Rel { + return o.ReadOrCreateWithCtx(ctx, vid.Interface(), mi.Fields.Pk.RelModelInfo.Fields.Pk.Name) + } else { + id = vid.Int() + } + + return false, id, err +} + +// insert model data to database +func (o *ormBase) Insert(md interface{}) (int64, error) { + return o.InsertWithCtx(context.Background(), md) +} + +func (o *ormBase) InsertWithCtx(ctx context.Context, md interface{}) (int64, error) { + mi, ind := o.getPtrMiInd(md) + id, err := o.alias.dbBaser.Insert(ctx, o.db, mi, ind, o.alias.tz) + if err != nil { + return id, err + } + + o.setPk(mi, ind, id) + + return id, nil +} + +// Set auto pk field +func (*ormBase) setPk(mi *models.ModelInfo, ind reflect.Value, id int64) { + if mi.Fields.Pk != nil && mi.Fields.Pk.Auto { + if mi.Fields.Pk.FieldType&IsPositiveIntegerField > 0 { + ind.FieldByIndex(mi.Fields.Pk.FieldIndex).SetUint(uint64(id)) + } else { + ind.FieldByIndex(mi.Fields.Pk.FieldIndex).SetInt(id) + } + } +} + +// insert some models to database +func (o *ormBase) InsertMulti(bulk int, mds interface{}) (int64, error) { + return o.InsertMultiWithCtx(context.Background(), bulk, mds) +} + +func (o *ormBase) InsertMultiWithCtx(ctx context.Context, bulk int, mds interface{}) (int64, error) { + var cnt int64 + + sind := reflect.Indirect(reflect.ValueOf(mds)) + + switch sind.Kind() { + case reflect.Array, reflect.Slice: + if sind.Len() == 0 { + return cnt, ErrArgs + } + default: + return cnt, ErrArgs + } + + if bulk <= 1 { + for i := 0; i < sind.Len(); i++ { + ind := reflect.Indirect(sind.Index(i)) + mi := o.getMi(ind.Interface()) + id, err := o.alias.dbBaser.Insert(ctx, o.db, mi, ind, o.alias.tz) + if err != nil { + return cnt, err + } + + o.setPk(mi, ind, id) + + cnt++ + } + } else { + mi := o.getMi(sind.Index(0).Interface()) + return o.alias.dbBaser.InsertMulti(ctx, o.db, mi, sind, bulk, o.alias.tz) + } + return cnt, nil +} + +// InsertOrUpdate data to database +func (o *ormBase) InsertOrUpdate(md interface{}, colConflictAndArgs ...string) (int64, error) { + return o.InsertOrUpdateWithCtx(context.Background(), md, colConflictAndArgs...) +} + +func (o *ormBase) InsertOrUpdateWithCtx(ctx context.Context, md interface{}, colConflitAndArgs ...string) (int64, error) { + mi, ind := o.getPtrMiInd(md) + id, err := o.alias.dbBaser.InsertOrUpdate(ctx, o.db, mi, ind, o.alias, colConflitAndArgs...) + if err != nil { + return id, err + } + + o.setPk(mi, ind, id) + + return id, nil +} + +// update model to database. +// cols Set the Columns those want to update. +func (o *ormBase) Update(md interface{}, cols ...string) (int64, error) { + return o.UpdateWithCtx(context.Background(), md, cols...) +} + +func (o *ormBase) UpdateWithCtx(ctx context.Context, md interface{}, cols ...string) (int64, error) { + mi, ind := o.getPtrMiInd(md) + return o.alias.dbBaser.Update(ctx, o.db, mi, ind, o.alias.tz, cols) +} + +// delete model in database +// cols shows the delete conditions values read from. default is pk +func (o *ormBase) Delete(md interface{}, cols ...string) (int64, error) { + return o.DeleteWithCtx(context.Background(), md, cols...) +} + +func (o *ormBase) DeleteWithCtx(ctx context.Context, md interface{}, cols ...string) (int64, error) { + mi, ind := o.getPtrMiInd(md) + num, err := o.alias.dbBaser.Delete(ctx, o.db, mi, ind, o.alias.tz, cols) + return num, err +} + +// create a models to models queryer +func (o *ormBase) QueryM2M(md interface{}, name string) QueryM2Mer { + mi, ind := o.getPtrMiInd(md) + fi := o.getFieldInfo(mi, name) + + switch { + case fi.FieldType == RelManyToMany: + case fi.FieldType == RelReverseMany && fi.ReverseFieldInfo.Mi.IsThrough: + default: + panic(fmt.Errorf(" model `%s` . name `%s` is not a m2m field", fi.Name, mi.FullName)) + } + + return newQueryM2M(md, o, mi, fi, ind) +} + +// NOTE: this method is deprecated, context parameter will not take effect. +func (o *ormBase) QueryM2MWithCtx(_ context.Context, md interface{}, name string) QueryM2Mer { + logs.Warn("QueryM2MWithCtx is DEPRECATED. Use methods with `WithCtx` suffix on QueryM2M as replacement please.") + return o.QueryM2M(md, name) +} + +// load related models to md model. +// args are limit, offset int and order string. +// +// example: +// +// orm.LoadRelated(post,"Tags") +// for _,tag := range post.Tags{...} +// +// make sure the relation is defined in model struct tags. +func (o *ormBase) LoadRelated(md interface{}, name string, args ...utils.KV) (int64, error) { + return o.LoadRelatedWithCtx(context.Background(), md, name, args...) +} + +func (o *ormBase) LoadRelatedWithCtx(_ context.Context, md interface{}, name string, args ...utils.KV) (int64, error) { + _, fi, ind, qs := o.queryRelated(md, name) + + var relDepth int + var limit, offset int64 + var order string + + kvs := utils.NewKVs(args...) + kvs.IfContains(hints.KeyRelDepth, func(value interface{}) { + if v, ok := value.(bool); ok { + if v { + relDepth = DefaultRelsDepth + } + } else if v, ok := value.(int); ok { + relDepth = v + } + }).IfContains(hints.KeyLimit, func(value interface{}) { + if v, ok := value.(int64); ok { + limit = v + } + }).IfContains(hints.KeyOffset, func(value interface{}) { + if v, ok := value.(int64); ok { + offset = v + } + }).IfContains(hints.KeyOrderBy, func(value interface{}) { + if v, ok := value.(string); ok { + order = v + } + }) + + switch fi.FieldType { + case RelOneToOne, RelForeignKey, RelReverseOne: + limit = 1 + offset = 0 + } + + qs.limit = limit + qs.offset = offset + qs.relDepth = relDepth + + if len(order) > 0 { + qs.orders = order_clause.ParseOrder(order) + } + + find := ind.FieldByIndex(fi.FieldIndex) + + var nums int64 + var err error + switch fi.FieldType { + case RelOneToOne, RelForeignKey, RelReverseOne: + val := reflect.New(find.Type().Elem()) + container := val.Interface() + err = qs.One(container) + if err == nil { + find.Set(val) + nums = 1 + } + default: + nums, err = qs.All(find.Addr().Interface()) + } + + return nums, err +} + +// Get QuerySeter for related models to md model +func (o *ormBase) queryRelated(md interface{}, name string) (*models.ModelInfo, *models.FieldInfo, reflect.Value, *querySet) { + mi, ind := o.getPtrMiInd(md) + fi := o.getFieldInfo(mi, name) + + _, _, exist := getExistPk(mi, ind) + if !exist { + panic(ErrMissPK) + } + + var qs *querySet + + switch fi.FieldType { + case RelOneToOne, RelForeignKey, RelManyToMany: + if !fi.InModel { + break + } + qs = o.getRelQs(md, mi, fi) + case RelReverseOne, RelReverseMany: + if !fi.InModel { + break + } + qs = o.getReverseQs(md, mi, fi) + } + + if qs == nil { + panic(fmt.Errorf(" name `%s` for model `%s` is not an available rel/reverse field", md, name)) + } + + return mi, fi, ind, qs +} + +// Get reverse relation QuerySeter +func (o *ormBase) getReverseQs(md interface{}, mi *models.ModelInfo, fi *models.FieldInfo) *querySet { + switch fi.FieldType { + case RelReverseOne, RelReverseMany: + default: + panic(fmt.Errorf(" name `%s` for model `%s` is not an available reverse field", fi.Name, mi.FullName)) + } + + var q *querySet + + if fi.FieldType == RelReverseMany && fi.ReverseFieldInfo.Mi.IsThrough { + q = newQuerySet(o, fi.RelModelInfo).(*querySet) + q.cond = NewCondition().And(fi.ReverseFieldInfoM2M.Column+ExprSep+fi.ReverseFieldInfo.Column, md) + } else { + q = newQuerySet(o, fi.ReverseFieldInfo.Mi).(*querySet) + q.cond = NewCondition().And(fi.ReverseFieldInfo.Column, md) + } + + return q +} + +// Get relation QuerySeter +func (o *ormBase) getRelQs(md interface{}, mi *models.ModelInfo, fi *models.FieldInfo) *querySet { + switch fi.FieldType { + case RelOneToOne, RelForeignKey, RelManyToMany: + default: + panic(fmt.Errorf(" name `%s` for model `%s` is not an available rel field", fi.Name, mi.FullName)) + } + + q := newQuerySet(o, fi.RelModelInfo).(*querySet) + q.cond = NewCondition() + + if fi.FieldType == RelManyToMany { + q.cond = q.cond.And(fi.ReverseFieldInfoM2M.Column+ExprSep+fi.ReverseFieldInfo.Column, md) + } else { + q.cond = q.cond.And(fi.ReverseFieldInfo.Column, md) + } + + return q +} + +// return a QuerySeter for table operations. +// table name can be string or struct. +// e.g. QueryTable("user"), QueryTable(&user{}) or QueryTable((*User)(nil)), +func (o *ormBase) QueryTable(ptrStructOrTableName interface{}) (qs QuerySeter) { + var name string + if table, ok := ptrStructOrTableName.(string); ok { + name = models.NameStrategyMap[models.DefaultNameStrategy](table) + if mi, ok := models.DefaultModelCache.Get(name); ok { + qs = newQuerySet(o, mi) + } + } else { + name = models.GetFullName(iutils.IndirectType(reflect.TypeOf(ptrStructOrTableName))) + if mi, ok := models.DefaultModelCache.GetByFullName(name); ok { + qs = newQuerySet(o, mi) + } + } + if qs == nil { + panic(fmt.Errorf(" table name: `%s` not exists", name)) + } + return qs +} + +// Deprecated: QueryTableWithCtx is deprecated, context parameter will not take effect. +func (o *ormBase) QueryTableWithCtx(_ context.Context, ptrStructOrTableName interface{}) (qs QuerySeter) { + logs.Warn("QueryTableWithCtx is DEPRECATED. Use methods with `WithCtx` suffix on QuerySeter as replacement please.") + return o.QueryTable(ptrStructOrTableName) +} + +// Raw return a raw query seter for raw sql string. +func (o *ormBase) Raw(query string, args ...interface{}) RawSeter { + return o.RawWithCtx(context.Background(), query, args...) +} + +func (o *ormBase) RawWithCtx(_ context.Context, query string, args ...interface{}) RawSeter { + return newRawSet(o, query, args) +} + +func (o *ormBase) ExecRaw(ctx context.Context, _ interface{}, query string, args ...any) (sql.Result, error) { + return o.alias.dbBaser.ExecRaw(ctx, o.db, query, args...) +} + +// Driver return current using database Driver +func (o *ormBase) Driver() Driver { + return driver(o.alias.name) +} + +// DBStats return sql.dbStats for current database +func (o *ormBase) DBStats() *sql.DBStats { + if o.alias != nil && o.alias.DB != nil { + stats := o.alias.DB.Stats() + return &stats + } + return nil +} + +type orm struct { + ormBase +} + +var _ Ormer = new(orm) + +func (o *orm) Begin() (TxOrmer, error) { + return o.BeginWithCtx(context.Background()) +} + +func (o *orm) BeginWithCtx(ctx context.Context) (TxOrmer, error) { + return o.BeginWithCtxAndOpts(ctx, nil) +} + +func (o *orm) BeginWithOpts(opts *sql.TxOptions) (TxOrmer, error) { + return o.BeginWithCtxAndOpts(context.Background(), opts) +} + +func (o *orm) BeginWithCtxAndOpts(ctx context.Context, opts *sql.TxOptions) (TxOrmer, error) { + tx, err := o.db.(txer).BeginTx(ctx, opts) + if err != nil { + return nil, err + } + + _txOrm := &txOrm{ + ormBase: ormBase{ + alias: o.alias, + db: &TxDB{tx: tx}, + }, + } + + if Debug { + _txOrm.db = newDbQueryLog(o.alias, _txOrm.db) + } + + var taskTxOrm TxOrmer = _txOrm + return taskTxOrm, nil +} + +func (o *orm) DoTx(task func(ctx context.Context, txOrm TxOrmer) error) error { + return o.DoTxWithCtx(context.Background(), task) +} + +func (o *orm) DoTxWithCtx(ctx context.Context, task func(ctx context.Context, txOrm TxOrmer) error) error { + return o.DoTxWithCtxAndOpts(ctx, nil, task) +} + +func (o *orm) DoTxWithOpts(opts *sql.TxOptions, task func(ctx context.Context, txOrm TxOrmer) error) error { + return o.DoTxWithCtxAndOpts(context.Background(), opts, task) +} + +func (o *orm) DoTxWithCtxAndOpts(ctx context.Context, opts *sql.TxOptions, task func(ctx context.Context, txOrm TxOrmer) error) error { + return doTxTemplate(ctx, o, opts, task) +} + +func doTxTemplate(ctx context.Context, o TxBeginner, opts *sql.TxOptions, + task func(ctx context.Context, txOrm TxOrmer) error) error { + _txOrm, err := o.BeginWithCtxAndOpts(ctx, opts) + if err != nil { + return err + } + panicked := true + defer func() { + if panicked || err != nil { + e := _txOrm.Rollback() + if e != nil { + logs.Error("rollback transaction failed: %v,%v", e, panicked) + } + } else { + e := _txOrm.Commit() + if e != nil { + logs.Error("commit transaction failed: %v,%v", e, panicked) + } + } + }() + taskTxOrm := _txOrm + err = task(ctx, taskTxOrm) + panicked = false + return err +} + +type txOrm struct { + ormBase +} + +var _ TxOrmer = new(txOrm) + +func (t *txOrm) Commit() error { + return t.db.(txEnder).Commit() +} + +func (t *txOrm) Rollback() error { + return t.db.(txEnder).Rollback() +} + +func (t *txOrm) RollbackUnlessCommit() error { + return t.db.(txEnder).RollbackUnlessCommit() +} + +// NewOrm create new orm +func NewOrm() Ormer { + return NewOrmUsingDB(`default`) +} + +// NewOrmUsingDB create new orm with the name +func NewOrmUsingDB(aliasName string) Ormer { + if al, ok := dataBaseCache.get(aliasName); ok { + return newOrmWithDB(al) + } + panic(fmt.Errorf(" unknown db alias name `%s`", aliasName)) +} + +// NewOrmWithDB create a new ormer object with specify *sql.db for query +func NewOrmWithDB(driverName, aliasName string, db *sql.DB, params ...DBOption) (Ormer, error) { + al, err := newDB(aliasName, driverName, db, params...) + if err != nil { + return nil, err + } + + return newOrmWithDB(al), nil +} + +func newOrmWithDB(al *DB) Ormer { + BootStrapWithDB(al.name) // execute only once + + o := new(orm) + o.alias = al + + if Debug { + o.db = newDbQueryLog(al, al.DB) + } else { + o.db = al.DB + } + + if len(globalFilterChains) > 0 { + return NewFilterOrmDecorator(o, globalFilterChains...) + } + return o +} diff --git a/orm/orm_conds.go b/client/orm/orm_conds.go similarity index 92% rename from orm/orm_conds.go rename to client/orm/orm_conds.go index f3fd66f0b1..9946e595ab 100644 --- a/orm/orm_conds.go +++ b/client/orm/orm_conds.go @@ -17,11 +17,13 @@ package orm import ( "fmt" "strings" + + "github.com/beego/beego/v2/client/orm/clauses" ) // ExprSep define the expression separation const ( - ExprSep = "__" + ExprSep = clauses.ExprSep ) type condValue struct { @@ -76,17 +78,19 @@ func (c Condition) AndNot(expr string, args ...interface{}) *Condition { // AndCond combine a condition to current condition func (c *Condition) AndCond(cond *Condition) *Condition { - c = c.clone() if c == cond { panic(fmt.Errorf(" cannot use self as sub cond")) } + + c = c.clone() + if cond != nil { c.params = append(c.params, condValue{cond: cond, isCond: true}) } return c } -// AndNotCond combine a AND NOT condition to current condition +// AndNotCond combine an AND NOT condition to current condition func (c *Condition) AndNotCond(cond *Condition) *Condition { c = c.clone() if c == cond { @@ -117,7 +121,7 @@ func (c Condition) OrNot(expr string, args ...interface{}) *Condition { return &c } -// OrCond combine a OR condition to current condition +// OrCond combine an OR condition to current condition func (c *Condition) OrCond(cond *Condition) *Condition { c = c.clone() if c == cond { @@ -129,7 +133,7 @@ func (c *Condition) OrCond(cond *Condition) *Condition { return c } -// OrNotCond combine a OR NOT condition to current condition +// OrNotCond combine an OR NOT condition to current condition func (c *Condition) OrNotCond(cond *Condition) *Condition { c = c.clone() if c == cond { @@ -149,5 +153,8 @@ func (c *Condition) IsEmpty() bool { // clone clone a condition func (c Condition) clone() *Condition { + params := make([]condValue, len(c.params)) + copy(params, c.params) + c.params = params return &c } diff --git a/orm/orm_log.go b/client/orm/orm_log.go similarity index 69% rename from orm/orm_log.go rename to client/orm/orm_log.go index f107bb59ef..7c2ce3a2ff 100644 --- a/orm/orm_log.go +++ b/client/orm/orm_log.go @@ -19,49 +19,51 @@ import ( "database/sql" "fmt" "io" - "log" "strings" "time" + + "github.com/beego/beego/v2/client/orm/internal/logs" ) -// Log implement the log.Logger -type Log struct { - *log.Logger +type Log = logs.Log + +// NewLog Set io.Writer to create a Logger. +func NewLog(out io.Writer) *logs.Log { + return logs.NewLog(out) } -//costomer log func +// LogFunc costomer log func var LogFunc func(query map[string]interface{}) -// NewLog set io.Writer to create a Logger. -func NewLog(out io.Writer) *Log { - d := new(Log) - d.Logger = log.New(out, "[ORM]", log.LstdFlags) - return d -} - -func debugLogQueies(alias *alias, operaton, query string, t time.Time, err error, args ...interface{}) { - var logMap = make(map[string]interface{}) - sub := time.Now().Sub(t) / 1e5 +func debugLogQueies(alias *DB, operation, query string, t time.Time, err error, args ...interface{}) { + logMap := make(map[string]interface{}) + sub := time.Since(t) / 1e5 elsp := float64(int(sub)) / 10.0 logMap["cost_time"] = elsp - flag := " OK" + flag := "OK" if err != nil { flag = "FAIL" } + logMap["flag"] = flag - con := fmt.Sprintf(" -[Queries/%s] - [%s / %11s / %7.1fms] - [%s]", alias.Name, flag, operaton, elsp, query) + con := fmt.Sprintf(" -[Queries/%s] - [ %s / %11s / %7.1fms] - [%s]", alias.name, flag, operation, elsp, query) + logMap["alias_name"] = alias.name + logMap["operation"] = operation + logMap["query"] = query cons := make([]string, 0, len(args)) for _, arg := range args { cons = append(cons, fmt.Sprintf("%v", arg)) } if len(cons) > 0 { con += fmt.Sprintf(" - `%s`", strings.Join(cons, "`, `")) + logMap["cons"] = cons } if err != nil { con += " - " + err.Error() + logMap["err"] = err } logMap["sql"] = fmt.Sprintf("%s-`%s`", query, strings.Join(cons, "`, `")) - if LogFunc != nil{ + if LogFunc != nil { LogFunc(logMap) } DebugLog.Println(con) @@ -70,7 +72,7 @@ func debugLogQueies(alias *alias, operaton, query string, t time.Time, err error // statement query logger struct. // if dev mode, use stmtQueryLog, or use stmtQuerier. type stmtQueryLog struct { - alias *alias + alias *DB query string stmt stmtQuerier } @@ -85,27 +87,39 @@ func (d *stmtQueryLog) Close() error { } func (d *stmtQueryLog) Exec(args ...interface{}) (sql.Result, error) { + return d.ExecContext(context.Background(), args...) +} + +func (d *stmtQueryLog) ExecContext(ctx context.Context, args ...interface{}) (sql.Result, error) { a := time.Now() - res, err := d.stmt.Exec(args...) + res, err := d.stmt.ExecContext(ctx, args...) debugLogQueies(d.alias, "st.Exec", d.query, a, err, args...) return res, err } func (d *stmtQueryLog) Query(args ...interface{}) (*sql.Rows, error) { + return d.QueryContext(context.Background(), args...) +} + +func (d *stmtQueryLog) QueryContext(ctx context.Context, args ...interface{}) (*sql.Rows, error) { a := time.Now() - res, err := d.stmt.Query(args...) + res, err := d.stmt.QueryContext(ctx, args...) debugLogQueies(d.alias, "st.Query", d.query, a, err, args...) return res, err } func (d *stmtQueryLog) QueryRow(args ...interface{}) *sql.Row { + return d.QueryRowContext(context.Background(), args...) +} + +func (d *stmtQueryLog) QueryRowContext(ctx context.Context, args ...interface{}) *sql.Row { a := time.Now() res := d.stmt.QueryRow(args...) debugLogQueies(d.alias, "st.QueryRow", d.query, a, nil, args...) return res } -func newStmtQueryLog(alias *alias, stmt stmtQuerier, query string) stmtQuerier { +func newStmtQueryLog(alias *DB, stmt stmtQuerier, query string) stmtQuerier { d := new(stmtQueryLog) d.stmt = stmt d.alias = alias @@ -116,21 +130,20 @@ func newStmtQueryLog(alias *alias, stmt stmtQuerier, query string) stmtQuerier { // database query logger struct. // if dev mode, use dbQueryLog, or use dbQuerier. type dbQueryLog struct { - alias *alias + alias *DB db dbQuerier tx txer txe txEnder } -var _ dbQuerier = new(dbQueryLog) -var _ txer = new(dbQueryLog) -var _ txEnder = new(dbQueryLog) +var ( + _ dbQuerier = new(dbQueryLog) + _ txer = new(dbQueryLog) + _ txEnder = new(dbQueryLog) +) func (d *dbQueryLog) Prepare(query string) (*sql.Stmt, error) { - a := time.Now() - stmt, err := d.db.Prepare(query) - debugLogQueies(d.alias, "db.Prepare", query, a, err) - return stmt, err + return d.PrepareContext(context.Background(), query) } func (d *dbQueryLog) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) { @@ -141,10 +154,7 @@ func (d *dbQueryLog) PrepareContext(ctx context.Context, query string) (*sql.Stm } func (d *dbQueryLog) Exec(query string, args ...interface{}) (sql.Result, error) { - a := time.Now() - res, err := d.db.Exec(query, args...) - debugLogQueies(d.alias, "db.Exec", query, a, err, args...) - return res, err + return d.ExecContext(context.Background(), query, args...) } func (d *dbQueryLog) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { @@ -155,10 +165,7 @@ func (d *dbQueryLog) ExecContext(ctx context.Context, query string, args ...inte } func (d *dbQueryLog) Query(query string, args ...interface{}) (*sql.Rows, error) { - a := time.Now() - res, err := d.db.Query(query, args...) - debugLogQueies(d.alias, "db.Query", query, a, err, args...) - return res, err + return d.QueryContext(context.Background(), query, args...) } func (d *dbQueryLog) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { @@ -169,10 +176,7 @@ func (d *dbQueryLog) QueryContext(ctx context.Context, query string, args ...int } func (d *dbQueryLog) QueryRow(query string, args ...interface{}) *sql.Row { - a := time.Now() - res := d.db.QueryRow(query, args...) - debugLogQueies(d.alias, "db.QueryRow", query, a, nil, args...) - return res + return d.QueryRowContext(context.Background(), query, args...) } func (d *dbQueryLog) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { @@ -183,10 +187,7 @@ func (d *dbQueryLog) QueryRowContext(ctx context.Context, query string, args ... } func (d *dbQueryLog) Begin() (*sql.Tx, error) { - a := time.Now() - tx, err := d.db.(txer).Begin() - debugLogQueies(d.alias, "db.Begin", "START TRANSACTION", a, err) - return tx, err + return d.BeginTx(context.Background(), nil) } func (d *dbQueryLog) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) { @@ -210,11 +211,18 @@ func (d *dbQueryLog) Rollback() error { return err } +func (d *dbQueryLog) RollbackUnlessCommit() error { + a := time.Now() + err := d.db.(txEnder).RollbackUnlessCommit() + debugLogQueies(d.alias, "tx.RollbackUnlessCommit", "ROLLBACK UNLESS COMMIT", a, err) + return err +} + func (d *dbQueryLog) SetDB(db dbQuerier) { d.db = db } -func newDbQueryLog(alias *alias, db dbQuerier) dbQuerier { +func newDbQueryLog(alias *DB, db dbQuerier) dbQuerier { d := new(dbQueryLog) d.alias = alias d.db = db diff --git a/orm/orm_object.go b/client/orm/orm_object.go similarity index 67% rename from orm/orm_object.go rename to client/orm/orm_object.go index de3181ce2b..83a5ac228a 100644 --- a/orm/orm_object.go +++ b/client/orm/orm_object.go @@ -15,14 +15,17 @@ package orm import ( + "context" "fmt" "reflect" + + "github.com/beego/beego/v2/client/orm/internal/models" ) // an insert queryer struct type insertSet struct { - mi *modelInfo - orm *orm + mi *models.ModelInfo + orm *ormBase stmt stmtQuerier closed bool } @@ -31,29 +34,33 @@ var _ Inserter = new(insertSet) // insert model ignore it's registered or not. func (o *insertSet) Insert(md interface{}) (int64, error) { + return o.InsertWithCtx(context.Background(), md) +} + +func (o *insertSet) InsertWithCtx(ctx context.Context, md interface{}) (int64, error) { if o.closed { return 0, ErrStmtClosed } val := reflect.ValueOf(md) ind := reflect.Indirect(val) typ := ind.Type() - name := getFullName(typ) + name := models.GetFullName(typ) if val.Kind() != reflect.Ptr { panic(fmt.Errorf(" cannot use non-ptr model struct `%s`", name)) } - if name != o.mi.fullName { - panic(fmt.Errorf(" need model `%s` but found `%s`", o.mi.fullName, name)) + if name != o.mi.FullName { + panic(fmt.Errorf(" need model `%s` but found `%s`", o.mi.FullName, name)) } - id, err := o.orm.alias.DbBaser.InsertStmt(o.stmt, o.mi, ind, o.orm.alias.TZ) + id, err := o.orm.alias.dbBaser.InsertStmt(ctx, o.stmt, o.mi, ind, o.orm.alias.tz) if err != nil { return id, err } if id > 0 { - if o.mi.fields.pk.auto { - if o.mi.fields.pk.fieldType&IsPositiveIntegerField > 0 { - ind.FieldByIndex(o.mi.fields.pk.fieldIndex).SetUint(uint64(id)) + if o.mi.Fields.Pk.Auto { + if o.mi.Fields.Pk.FieldType&IsPositiveIntegerField > 0 { + ind.FieldByIndex(o.mi.Fields.Pk.FieldIndex).SetUint(uint64(id)) } else { - ind.FieldByIndex(o.mi.fields.pk.fieldIndex).SetInt(id) + ind.FieldByIndex(o.mi.Fields.Pk.FieldIndex).SetInt(id) } } } @@ -70,11 +77,11 @@ func (o *insertSet) Close() error { } // create new insert queryer. -func newInsertSet(orm *orm, mi *modelInfo) (Inserter, error) { +func newInsertSet(ctx context.Context, orm *ormBase, mi *models.ModelInfo) (Inserter, error) { bi := new(insertSet) bi.orm = orm bi.mi = mi - st, query, err := orm.alias.DbBaser.PrepareInsert(orm.db, mi) + st, query, err := orm.alias.dbBaser.PrepareInsert(ctx, orm.db, mi) if err != nil { return nil, err } diff --git a/orm/orm_querym2m.go b/client/orm/orm_querym2m.go similarity index 56% rename from orm/orm_querym2m.go rename to client/orm/orm_querym2m.go index 6a270a0d86..4811e39b58 100644 --- a/orm/orm_querym2m.go +++ b/client/orm/orm_querym2m.go @@ -14,40 +14,50 @@ package orm -import "reflect" +import ( + "context" + "reflect" + + "github.com/beego/beego/v2/client/orm/internal/models" +) // model to model struct type queryM2M struct { md interface{} - mi *modelInfo - fi *fieldInfo + mi *models.ModelInfo + fi *models.FieldInfo qs *querySet ind reflect.Value } // add models to origin models when creating queryM2M. // example: -// m2m := orm.QueryM2M(post,"Tag") -// m2m.Add(&Tag1{},&Tag2{}) -// for _,tag := range post.Tags{} +// +// m2m := orm.QueryM2M(post,"Tag") +// m2m.Add(&Tag1{},&Tag2{}) +// for _,tag := range post.Tags{} // // make sure the relation is defined in post model struct tag. func (o *queryM2M) Add(mds ...interface{}) (int64, error) { + return o.AddWithCtx(context.Background(), mds...) +} + +func (o *queryM2M) AddWithCtx(ctx context.Context, mds ...interface{}) (int64, error) { fi := o.fi - mi := fi.relThroughModelInfo - mfi := fi.reverseFieldInfo - rfi := fi.reverseFieldInfoTwo + mi := fi.RelThroughModelInfo + mfi := fi.ReverseFieldInfo + rfi := fi.ReverseFieldInfoTwo orm := o.qs.orm - dbase := orm.alias.DbBaser + dbase := orm.alias.dbBaser var models []interface{} var otherValues []interface{} var otherNames []string - for _, colname := range mi.fields.dbcols { - if colname != mfi.column && colname != rfi.column && colname != fi.mi.fields.pk.column && - mi.fields.columns[colname] != mi.fields.pk { + for _, colname := range mi.Fields.DBcols { + if colname != mfi.Column && colname != rfi.Column && colname != fi.Mi.Fields.Pk.Column && + mi.Fields.Columns[colname] != mi.Fields.Pk { otherNames = append(otherNames, colname) } } @@ -76,7 +86,7 @@ func (o *queryM2M) Add(mds ...interface{}) (int64, error) { panic(ErrMissPK) } - names := []string{mfi.column, rfi.column} + names := []string{mfi.Column, rfi.Column} values := make([]interface{}, 0, len(models)*2) for _, md := range models { @@ -86,7 +96,7 @@ func (o *queryM2M) Add(mds ...interface{}) (int64, error) { if ind.Kind() != reflect.Struct { v2 = ind.Interface() } else { - _, v2, exist = getExistPk(fi.relModelInfo, ind) + _, v2, exist = getExistPk(fi.RelModelInfo, ind) if !exist { panic(ErrMissPK) } @@ -96,45 +106,61 @@ func (o *queryM2M) Add(mds ...interface{}) (int64, error) { } names = append(names, otherNames...) values = append(values, otherValues...) - return dbase.InsertValue(orm.db, mi, true, names, values) + return dbase.InsertValue(ctx, orm.db, mi, true, names, values) } // remove models following the origin model relationship func (o *queryM2M) Remove(mds ...interface{}) (int64, error) { + return o.RemoveWithCtx(context.Background(), mds...) +} + +func (o *queryM2M) RemoveWithCtx(ctx context.Context, mds ...interface{}) (int64, error) { fi := o.fi - qs := o.qs.Filter(fi.reverseFieldInfo.name, o.md) + qs := o.qs.Filter(fi.ReverseFieldInfo.Name, o.md) - return qs.Filter(fi.reverseFieldInfoTwo.name+ExprSep+"in", mds).Delete() + return qs.Filter(fi.ReverseFieldInfoTwo.Name+ExprSep+"in", mds).Delete() } // check model is existed in relationship of origin model func (o *queryM2M) Exist(md interface{}) bool { + return o.ExistWithCtx(context.Background(), md) +} + +func (o *queryM2M) ExistWithCtx(ctx context.Context, md interface{}) bool { fi := o.fi - return o.qs.Filter(fi.reverseFieldInfo.name, o.md). - Filter(fi.reverseFieldInfoTwo.name, md).Exist() + return o.qs.Filter(fi.ReverseFieldInfo.Name, o.md). + Filter(fi.ReverseFieldInfoTwo.Name, md).ExistWithCtx(ctx) } -// clean all models in related of origin model +// Clean All models in related of origin model func (o *queryM2M) Clear() (int64, error) { + return o.ClearWithCtx(context.Background()) +} + +func (o *queryM2M) ClearWithCtx(ctx context.Context) (int64, error) { fi := o.fi - return o.qs.Filter(fi.reverseFieldInfo.name, o.md).Delete() + return o.qs.Filter(fi.ReverseFieldInfo.Name, o.md).DeleteWithCtx(ctx) } -// count all related models of origin model +// count All related models of origin model func (o *queryM2M) Count() (int64, error) { + return o.CountWithCtx(context.Background()) +} + +func (o *queryM2M) CountWithCtx(ctx context.Context) (int64, error) { fi := o.fi - return o.qs.Filter(fi.reverseFieldInfo.name, o.md).Count() + return o.qs.Filter(fi.ReverseFieldInfo.Name, o.md).CountWithCtx(ctx) } var _ QueryM2Mer = new(queryM2M) // create new M2M queryer. -func newQueryM2M(md interface{}, o *orm, mi *modelInfo, fi *fieldInfo, ind reflect.Value) QueryM2Mer { +func newQueryM2M(md interface{}, o *ormBase, mi *models.ModelInfo, fi *models.FieldInfo, ind reflect.Value) QueryM2Mer { qm2m := new(queryM2M) qm2m.md = md qm2m.mi = mi qm2m.fi = fi qm2m.ind = ind - qm2m.qs = newQuerySet(o, fi.relThroughModelInfo).(*querySet) + qm2m.qs = newQuerySet(o, fi.RelThroughModelInfo).(*querySet) return qm2m } diff --git a/client/orm/orm_queryset.go b/client/orm/orm_queryset.go new file mode 100644 index 0000000000..6bf8212cf2 --- /dev/null +++ b/client/orm/orm_queryset.go @@ -0,0 +1,388 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "context" + "fmt" + + "github.com/beego/beego/v2/client/orm/internal/utils" + + "github.com/beego/beego/v2/client/orm/internal/models" + + "github.com/beego/beego/v2/client/orm/clauses/order_clause" + "github.com/beego/beego/v2/client/orm/hints" +) + +type colValue struct { + value int64 + opt operator +} + +type operator int + +// define Col operations +const ( + ColAdd operator = iota + ColMinus + ColMultiply + ColExcept + ColBitAnd + ColBitRShift + ColBitLShift + ColBitXOR + ColBitOr +) + +// ColValue do the field raw changes. e.g Nums = Nums + 10. usage: +// +// Params{ +// "Nums": ColValue(Col_Add, 10), +// } +func ColValue(opt operator, value interface{}) interface{} { + switch opt { + case ColAdd, ColMinus, ColMultiply, ColExcept, ColBitAnd, ColBitRShift, + ColBitLShift, ColBitXOR, ColBitOr: + default: + panic(fmt.Errorf("orm.ColValue wrong operator")) + } + v, err := utils.StrTo(utils.ToStr(value)).Int64() + if err != nil { + panic(fmt.Errorf("orm.ColValue doesn't support non string/numeric type, %s", err)) + } + var val colValue + val.value = v + val.opt = opt + return val +} + +// real query struct +type querySet struct { + mi *models.ModelInfo + cond *Condition + related []string + relDepth int + limit int64 + offset int64 + groups []string + orders []*order_clause.Order + distinct bool + forUpdate bool + useIndex int + indexes []string + orm *ormBase + aggregate string +} + +var _ QuerySeter = new(querySet) + +// add condition expression to QuerySeter. +func (o querySet) Filter(expr string, args ...interface{}) QuerySeter { + if o.cond == nil { + o.cond = NewCondition() + } + o.cond = o.cond.And(expr, args...) + return &o +} + +// add raw sql to querySeter. +func (o querySet) FilterRaw(expr string, sql string) QuerySeter { + if o.cond == nil { + o.cond = NewCondition() + } + o.cond = o.cond.Raw(expr, sql) + return &o +} + +// add NOT condition to querySeter. +func (o querySet) Exclude(expr string, args ...interface{}) QuerySeter { + if o.cond == nil { + o.cond = NewCondition() + } + o.cond = o.cond.AndNot(expr, args...) + return &o +} + +// Set offset number +func (o *querySet) setOffset(num interface{}) { + o.offset = utils.ToInt64(num) +} + +// add LIMIT value. +// args[0] means offset, e.g. LIMIT num,offset. +func (o querySet) Limit(limit interface{}, args ...interface{}) QuerySeter { + o.limit = utils.ToInt64(limit) + if len(args) > 0 { + o.setOffset(args[0]) + } + return &o +} + +// add OFFSET value +func (o querySet) Offset(offset interface{}) QuerySeter { + o.setOffset(offset) + return &o +} + +// add GROUP expression +func (o querySet) GroupBy(exprs ...string) QuerySeter { + o.groups = exprs + return &o +} + +// add ORDER expression. +// "column" means ASC, "-column" means DESC. +func (o querySet) OrderBy(expressions ...string) QuerySeter { + if len(expressions) <= 0 { + return &o + } + o.orders = order_clause.ParseOrder(expressions...) + return &o +} + +// add ORDER expression. +func (o querySet) OrderClauses(orders ...*order_clause.Order) QuerySeter { + if len(orders) <= 0 { + return &o + } + o.orders = orders + return &o +} + +// add DISTINCT to SELECT +func (o querySet) Distinct() QuerySeter { + o.distinct = true + return &o +} + +// add FOR UPDATE to SELECT +func (o querySet) ForUpdate() QuerySeter { + o.forUpdate = true + return &o +} + +// ForceIndex force index for query +func (o querySet) ForceIndex(indexes ...string) QuerySeter { + o.useIndex = hints.KeyForceIndex + o.indexes = indexes + return &o +} + +// UseIndex use index for query +func (o querySet) UseIndex(indexes ...string) QuerySeter { + o.useIndex = hints.KeyUseIndex + o.indexes = indexes + return &o +} + +// IgnoreIndex ignore index for query +func (o querySet) IgnoreIndex(indexes ...string) QuerySeter { + o.useIndex = hints.KeyIgnoreIndex + o.indexes = indexes + return &o +} + +// Set relation model to query together. +// it will query relation models and assign to parent model. +func (o querySet) RelatedSel(params ...interface{}) QuerySeter { + if len(params) == 0 { + o.relDepth = DefaultRelsDepth + } else { + for _, p := range params { + switch val := p.(type) { + case string: + o.related = append(o.related, val) + case int: + o.relDepth = val + default: + panic(fmt.Errorf(" wrong param kind: %v", val)) + } + } + } + return &o +} + +// Set condition to QuerySeter. +func (o querySet) SetCond(cond *Condition) QuerySeter { + o.cond = cond + return &o +} + +// Get condition from QuerySeter +func (o querySet) GetCond() *Condition { + return o.cond +} + +// return QuerySeter execution result number +func (o querySet) Count() (int64, error) { + return o.CountWithCtx(context.Background()) +} + +func (o querySet) CountWithCtx(ctx context.Context) (int64, error) { + return o.orm.alias.dbBaser.Count(ctx, o.orm.db, o, o.mi, o.cond, o.orm.alias.tz) +} + +// check result empty or not after QuerySeter executed +func (o querySet) Exist() bool { + return o.ExistWithCtx(context.Background()) +} + +func (o querySet) ExistWithCtx(ctx context.Context) bool { + cnt, _ := o.orm.alias.dbBaser.Count(ctx, o.orm.db, o, o.mi, o.cond, o.orm.alias.tz) + return cnt > 0 +} + +// execute update with parameters +func (o querySet) Update(values Params) (int64, error) { + return o.UpdateWithCtx(context.Background(), values) +} + +func (o querySet) UpdateWithCtx(ctx context.Context, values Params) (int64, error) { + return o.orm.alias.dbBaser.UpdateBatch(ctx, o.orm.db, &o, o.mi, o.cond, values, o.orm.alias.tz) +} + +// execute delete +func (o querySet) Delete() (int64, error) { + return o.DeleteWithCtx(context.Background()) +} + +func (o querySet) DeleteWithCtx(ctx context.Context) (int64, error) { + return o.orm.alias.dbBaser.DeleteBatch(ctx, o.orm.db, &o, o.mi, o.cond, o.orm.alias.tz) +} + +// PrepareInsert return an insert queryer. +// it can be used in times. +// example: +// +// i,err := sq.PrepareInsert() +// i.Add(&user1{},&user2{}) +func (o querySet) PrepareInsert() (Inserter, error) { + return o.PrepareInsertWithCtx(context.Background()) +} + +func (o querySet) PrepareInsertWithCtx(ctx context.Context) (Inserter, error) { + return newInsertSet(ctx, o.orm, o.mi) +} + +// All query all data and map to containers. +// cols means the Columns when querying. +func (o querySet) All(container interface{}, cols ...string) (int64, error) { + return o.AllWithCtx(context.Background(), container, cols...) +} + +// AllWithCtx see All +func (o querySet) AllWithCtx(ctx context.Context, container interface{}, cols ...string) (int64, error) { + return o.orm.alias.dbBaser.ReadBatch(ctx, o.orm.db, o, o.mi, o.cond, container, o.orm.alias.tz, cols) +} + +// One query one row data and map to containers. +// cols means the Columns when querying. +func (o querySet) One(container interface{}, cols ...string) error { + return o.OneWithCtx(context.Background(), container, cols...) +} + +// OneWithCtx check One +func (o querySet) OneWithCtx(ctx context.Context, container interface{}, cols ...string) error { + o.limit = 1 + num, err := o.orm.alias.dbBaser.ReadBatch(ctx, o.orm.db, o, o.mi, o.cond, container, o.orm.alias.tz, cols) + if err != nil { + return err + } + if num == 0 { + return ErrNoRows + } + + if num > 1 { + return ErrMultiRows + } + return nil +} + +// Values query All data and map to []map[string]interface. +// expres means condition expression. +// it converts data to []map[column]value. +func (o querySet) Values(results *[]Params, exprs ...string) (int64, error) { + return o.ValuesWithCtx(context.Background(), results, exprs...) +} + +// ValuesWithCtx see Values +func (o querySet) ValuesWithCtx(ctx context.Context, results *[]Params, exprs ...string) (int64, error) { + return o.orm.alias.dbBaser.ReadValues(ctx, o.orm.db, o, o.mi, o.cond, exprs, results, o.orm.alias.tz) +} + +// ValuesList query data and map to [][]interface +// it converts data to [][column_index]value +func (o querySet) ValuesList(results *[]ParamsList, exprs ...string) (int64, error) { + return o.ValuesListWithCtx(context.Background(), results, exprs...) +} + +func (o querySet) ValuesListWithCtx(ctx context.Context, results *[]ParamsList, exprs ...string) (int64, error) { + return o.orm.alias.dbBaser.ReadValues(ctx, o.orm.db, o, o.mi, o.cond, exprs, results, o.orm.alias.tz) +} + +// ValuesFlat query all data and map to []interface. +// it's designed for one row record Set, auto change to []value, not [][column]value. +func (o querySet) ValuesFlat(result *ParamsList, expr string) (int64, error) { + return o.ValuesFlatWithCtx(context.Background(), result, expr) +} + +// ValuesFlatWithCtx see ValuesFlat +func (o querySet) ValuesFlatWithCtx(ctx context.Context, result *ParamsList, expr string) (int64, error) { + return o.orm.alias.dbBaser.ReadValues(ctx, o.orm.db, o, o.mi, o.cond, []string{expr}, result, o.orm.alias.tz) +} + +// RowsToMap query rows into map[string]interface with specify key and value column name. +// keyCol = "name", valueCol = "value" +// table data +// name | value +// total | 100 +// found | 200 +// +// to map[string]interface{}{ +// "total": 100, +// "found": 200, +// } +func (o querySet) RowsToMap(result *Params, keyCol, valueCol string) (int64, error) { + panic(ErrNotImplement) +} + +// RowsToStruct query rows into struct with specify key and value column name. +// keyCol = "name", valueCol = "value" +// table data +// name | value +// total | 100 +// found | 200 +// +// to struct { +// Total int +// Found int +// } +func (o querySet) RowsToStruct(ptrStruct interface{}, keyCol, valueCol string) (int64, error) { + panic(ErrNotImplement) +} + +// create new QuerySeter. +func newQuerySet(orm *ormBase, mi *models.ModelInfo) QuerySeter { + o := new(querySet) + o.mi = mi + o.orm = orm + return o +} + +// aggregate func +func (o querySet) Aggregate(s string) QuerySeter { + o.aggregate = s + return &o +} diff --git a/orm/orm_raw.go b/client/orm/orm_raw.go similarity index 76% rename from orm/orm_raw.go rename to client/orm/orm_raw.go index 3325a7ea71..b382953d37 100644 --- a/orm/orm_raw.go +++ b/client/orm/orm_raw.go @@ -16,9 +16,13 @@ package orm import ( "database/sql" + "errors" "fmt" "reflect" "time" + + "github.com/beego/beego/v2/client/orm/internal/models" + "github.com/beego/beego/v2/client/orm/internal/utils" ) // raw sql string prepared statement @@ -32,7 +36,8 @@ func (o *rawPrepare) Exec(args ...interface{}) (sql.Result, error) { if o.closed { return nil, ErrStmtClosed } - return o.stmt.Exec(args...) + flatParams := getFlatParams(nil, args, o.rs.orm.alias.tz) + return o.stmt.Exec(flatParams...) } func (o *rawPrepare) Close() error { @@ -45,7 +50,7 @@ func newRawPreparer(rs *rawSet) (RawPreparer, error) { o.rs = rs query := rs.query - rs.orm.alias.DbBaser.ReplaceMarks(&query) + rs.orm.alias.dbBaser.ReplaceMarks(&query) st, err := rs.orm.db.Prepare(query) if err != nil { @@ -63,12 +68,12 @@ func newRawPreparer(rs *rawSet) (RawPreparer, error) { type rawSet struct { query string args []interface{} - orm *orm + orm *ormBase } var _ RawSeter = new(rawSet) -// set args for every query +// Set args for every query func (o rawSet) SetArgs(args ...interface{}) RawSeter { o.args = args return &o @@ -77,13 +82,13 @@ func (o rawSet) SetArgs(args ...interface{}) RawSeter { // execute raw sql and return sql.Result func (o *rawSet) Exec() (sql.Result, error) { query := o.query - o.orm.alias.DbBaser.ReplaceMarks(&query) + o.orm.alias.dbBaser.ReplaceMarks(&query) - args := getFlatParams(nil, o.args, o.orm.alias.TZ) + args := getFlatParams(nil, o.args, o.orm.alias.tz) return o.orm.db.Exec(query, args...) } -// set field value to row container +// Set field value to row container func (o *rawSet) setFieldValue(ind reflect.Value, value interface{}) { switch ind.Kind() { case reflect.Bool: @@ -92,7 +97,7 @@ func (o *rawSet) setFieldValue(ind reflect.Value, value interface{}) { } else if v, ok := value.(bool); ok { ind.SetBool(v) } else { - v, _ := StrTo(ToStr(value)).Bool() + v, _ := utils.StrTo(utils.ToStr(value)).Bool() ind.SetBool(v) } @@ -100,7 +105,7 @@ func (o *rawSet) setFieldValue(ind reflect.Value, value interface{}) { if value == nil { ind.SetString("") } else { - ind.SetString(ToStr(value)) + ind.SetString(utils.ToStr(value)) } case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: @@ -114,7 +119,7 @@ func (o *rawSet) setFieldValue(ind reflect.Value, value interface{}) { case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: ind.SetInt(int64(val.Uint())) default: - v, _ := StrTo(ToStr(value)).Int64() + v, _ := utils.StrTo(utils.ToStr(value)).Int64() ind.SetInt(v) } } @@ -129,7 +134,7 @@ func (o *rawSet) setFieldValue(ind reflect.Value, value interface{}) { case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: ind.SetUint(val.Uint()) default: - v, _ := StrTo(ToStr(value)).Uint64() + v, _ := utils.StrTo(utils.ToStr(value)).Uint64() ind.SetUint(v) } } @@ -142,7 +147,7 @@ func (o *rawSet) setFieldValue(ind reflect.Value, value interface{}) { case reflect.Float64: ind.SetFloat(val.Float()) default: - v, _ := StrTo(ToStr(value)).Float64() + v, _ := utils.StrTo(utils.ToStr(value)).Float64() ind.SetFloat(v) } } @@ -157,7 +162,7 @@ func (o *rawSet) setFieldValue(ind reflect.Value, value interface{}) { var str string switch d := value.(type) { case time.Time: - o.orm.alias.DbBaser.TimeFromDB(&d, o.orm.alias.TZ) + o.orm.alias.dbBaser.TimeFromDB(&d, o.orm.alias.tz) ind.Set(reflect.ValueOf(d)) case []byte: str = string(d) @@ -167,14 +172,20 @@ func (o *rawSet) setFieldValue(ind reflect.Value, value interface{}) { if str != "" { if len(str) >= 19 { str = str[:19] - t, err := time.ParseInLocation(formatDateTime, str, o.orm.alias.TZ) + t, err := time.ParseInLocation(utils.FormatDateTime, str, o.orm.alias.tz) if err == nil { t = t.In(DefaultTimeLoc) ind.Set(reflect.ValueOf(t)) } } else if len(str) >= 10 { str = str[:10] - t, err := time.ParseInLocation(formatDate, str, DefaultTimeLoc) + t, err := time.ParseInLocation(utils.FormatDate, str, DefaultTimeLoc) + if err == nil { + ind.Set(reflect.ValueOf(t)) + } + } else if len(str) >= 8 { + str = str[:8] + t, err := time.ParseInLocation(utils.FormatTime, str, DefaultTimeLoc) if err == nil { ind.Set(reflect.ValueOf(t)) } @@ -202,7 +213,7 @@ func (o *rawSet) setFieldValue(ind reflect.Value, value interface{}) { } } -// set field value in loop for slice container +// Set field value in loop for slice container func (o *rawSet) loopSetRefs(refs []interface{}, sInds []reflect.Value, nIndsPtr *[]reflect.Value, eTyps []reflect.Type, init bool) { nInds := *nIndsPtr @@ -244,7 +255,6 @@ func (o *rawSet) loopSetRefs(refs []interface{}, sInds []reflect.Value, nIndsPtr } cur++ } - } else { value := reflect.ValueOf(refs[cur]).Elem().Interface() if isPtr && value == nil { @@ -279,7 +289,7 @@ func (o *rawSet) QueryRow(containers ...interface{}) error { refs = make([]interface{}, 0, len(containers)) sInds []reflect.Value eTyps []reflect.Type - sMi *modelInfo + sMi *models.ModelInfo ) structMode := false for _, container := range containers { @@ -287,7 +297,7 @@ func (o *rawSet) QueryRow(containers ...interface{}) error { ind := reflect.Indirect(val) if val.Kind() != reflect.Ptr { - panic(fmt.Errorf(" all args must be use ptr")) + panic(errors.New(" All args must be use ptr")) } etyp := ind.Type() @@ -301,12 +311,12 @@ func (o *rawSet) QueryRow(containers ...interface{}) error { if typ.Kind() == reflect.Struct && typ.String() != "time.Time" { if len(containers) > 1 { - panic(fmt.Errorf(" now support one struct only. see #384")) + panic(errors.New(" now support one struct only. see #384")) } structMode = true - fn := getFullName(typ) - if mi, ok := modelCache.getByFullName(fn); ok { + fn := models.GetFullName(typ) + if mi, ok := models.DefaultModelCache.GetByFullName(fn); ok { sMi = mi } } else { @@ -316,9 +326,9 @@ func (o *rawSet) QueryRow(containers ...interface{}) error { } query := o.query - o.orm.alias.DbBaser.ReplaceMarks(&query) + o.orm.alias.dbBaser.ReplaceMarks(&query) - args := getFlatParams(nil, o.args, o.orm.alias.TZ) + args := getFlatParams(nil, o.args, o.orm.alias.tz) rows, err := o.orm.db.Query(query, args...) if err != nil { if err == sql.ErrNoRows { @@ -327,6 +337,8 @@ func (o *rawSet) QueryRow(containers ...interface{}) error { return err } + structTagMap := make(map[reflect.StructTag]map[string]string) + defer rows.Close() if rows.Next() { @@ -360,31 +372,58 @@ func (o *rawSet) QueryRow(containers ...interface{}) error { if sMi != nil { for _, col := range columns { - if fi := sMi.fields.GetByColumn(col); fi != nil { + if fi := sMi.Fields.GetByColumn(col); fi != nil { value := reflect.ValueOf(columnsMp[col]).Elem().Interface() - field := ind.FieldByIndex(fi.fieldIndex) - if fi.fieldType&IsRelField > 0 { - mf := reflect.New(fi.relModelInfo.addrField.Elem().Type()) + field := ind.FieldByIndex(fi.FieldIndex) + if fi.FieldType&IsRelField > 0 { + mf := reflect.New(fi.RelModelInfo.AddrField.Elem().Type()) field.Set(mf) - field = mf.Elem().FieldByIndex(fi.relModelInfo.fields.pk.fieldIndex) + field = mf.Elem().FieldByIndex(fi.RelModelInfo.Fields.Pk.FieldIndex) + } + if fi.IsFielder { + fd := field.Addr().Interface().(models.Fielder) + err := fd.SetRaw(value) + if err != nil { + return fmt.Errorf("Set raw error: %w", err) + } + } else { + o.setFieldValue(field, value) } - o.setFieldValue(field, value) } } } else { - for i := 0; i < ind.NumField(); i++ { - f := ind.Field(i) - fe := ind.Type().Field(i) - _, tags := parseStructTag(fe.Tag.Get(defaultStructTagName)) - var col string - if col = tags["column"]; col == "" { - col = nameStrategyMap[nameStrategy](fe.Name) - } - if v, ok := columnsMp[col]; ok { - value := reflect.ValueOf(v).Elem().Interface() - o.setFieldValue(f, value) + // define recursive function + var recursiveSetField func(rv reflect.Value) + recursiveSetField = func(rv reflect.Value) { + for i := 0; i < rv.NumField(); i++ { + f := rv.Field(i) + fe := rv.Type().Field(i) + + // check if the field is a Struct + // recursive the Struct type + if fe.Type.Kind() == reflect.Struct { + recursiveSetField(f) + } + + // thanks @Gazeboxu. + tags := structTagMap[fe.Tag] + if tags == nil { + _, tags = models.ParseStructTag(fe.Tag.Get(models.DefaultStructTagName)) + structTagMap[fe.Tag] = tags + } + var col string + if col = tags["column"]; col == "" { + col = models.NameStrategyMap[models.NameStrategy](fe.Name) + } + if v, ok := columnsMp[col]; ok { + value := reflect.ValueOf(v).Elem().Interface() + o.setFieldValue(f, value) + } } } + + // init call the recursive function + recursiveSetField(ind) } } else { @@ -399,7 +438,6 @@ func (o *rawSet) QueryRow(containers ...interface{}) error { sInd.Set(nInd) } } - } else { return ErrNoRows } @@ -407,20 +445,20 @@ func (o *rawSet) QueryRow(containers ...interface{}) error { return nil } -// query data rows and map to container +// QueryRows query data rows and map to container func (o *rawSet) QueryRows(containers ...interface{}) (int64, error) { var ( refs = make([]interface{}, 0, len(containers)) sInds []reflect.Value eTyps []reflect.Type - sMi *modelInfo + sMi *models.ModelInfo ) structMode := false for _, container := range containers { val := reflect.ValueOf(container) sInd := reflect.Indirect(val) if val.Kind() != reflect.Ptr || sInd.Kind() != reflect.Slice { - panic(fmt.Errorf(" all args must be use ptr slice")) + panic(errors.New(" All args must be use ptr slice")) } etyp := sInd.Type().Elem() @@ -434,12 +472,12 @@ func (o *rawSet) QueryRows(containers ...interface{}) (int64, error) { if typ.Kind() == reflect.Struct && typ.String() != "time.Time" { if len(containers) > 1 { - panic(fmt.Errorf(" now support one struct only. see #384")) + panic(errors.New(" now support one struct only. see #384")) } structMode = true - fn := getFullName(typ) - if mi, ok := modelCache.getByFullName(fn); ok { + fn := models.GetFullName(typ) + if mi, ok := models.DefaultModelCache.GetByFullName(fn); ok { sMi = mi } } else { @@ -449,9 +487,9 @@ func (o *rawSet) QueryRows(containers ...interface{}) (int64, error) { } query := o.query - o.orm.alias.DbBaser.ReplaceMarks(&query) + o.orm.alias.dbBaser.ReplaceMarks(&query) - args := getFlatParams(nil, o.args, o.orm.alias.TZ) + args := getFlatParams(nil, o.args, o.orm.alias.tz) rows, err := o.orm.db.Query(query, args...) if err != nil { return 0, err @@ -464,7 +502,6 @@ func (o *rawSet) QueryRows(containers ...interface{}) (int64, error) { sInd := sInds[0] for rows.Next() { - if structMode { columns, err := rows.Columns() if err != nil { @@ -501,15 +538,23 @@ func (o *rawSet) QueryRows(containers ...interface{}) (int64, error) { if sMi != nil { for _, col := range columns { - if fi := sMi.fields.GetByColumn(col); fi != nil { + if fi := sMi.Fields.GetByColumn(col); fi != nil { value := reflect.ValueOf(columnsMp[col]).Elem().Interface() - field := ind.FieldByIndex(fi.fieldIndex) - if fi.fieldType&IsRelField > 0 { - mf := reflect.New(fi.relModelInfo.addrField.Elem().Type()) + field := ind.FieldByIndex(fi.FieldIndex) + if fi.FieldType&IsRelField > 0 { + mf := reflect.New(fi.RelModelInfo.AddrField.Elem().Type()) field.Set(mf) - field = mf.Elem().FieldByIndex(fi.relModelInfo.fields.pk.fieldIndex) + field = mf.Elem().FieldByIndex(fi.RelModelInfo.Fields.Pk.FieldIndex) + } + if fi.IsFielder { + fd := field.Addr().Interface().(models.Fielder) + err := fd.SetRaw(value) + if err != nil { + return 0, fmt.Errorf("Set raw error: %w", err) + } + } else { + o.setFieldValue(field, value) } - o.setFieldValue(field, value) } } } else { @@ -526,10 +571,10 @@ func (o *rawSet) QueryRows(containers ...interface{}) (int64, error) { recursiveSetField(f) } - _, tags := parseStructTag(fe.Tag.Get(defaultStructTagName)) + _, tags := models.ParseStructTag(fe.Tag.Get(models.DefaultStructTagName)) var col string if col = tags["column"]; col == "" { - col = nameStrategyMap[nameStrategy](fe.Name) + col = models.NameStrategyMap[models.NameStrategy](fe.Name) } if v, ok := columnsMp[col]; ok { value := reflect.ValueOf(v).Elem().Interface() @@ -549,18 +594,19 @@ func (o *rawSet) QueryRows(containers ...interface{}) (int64, error) { sInd = reflect.Append(sInd, ind) } else { - if err := rows.Scan(refs...); err != nil { + if err = rows.Scan(refs...); err != nil { return 0, err } - o.loopSetRefs(refs, sInds, &nInds, eTyps, cnt == 0) } - cnt++ } - if cnt > 0 { + if err = rows.Err(); err != nil { + return 0, err + } + if cnt > 0 { if structMode { sInds[0].Set(sInd) } else { @@ -594,9 +640,9 @@ func (o *rawSet) readValues(container interface{}, needCols []string) (int64, er } query := o.query - o.orm.alias.DbBaser.ReplaceMarks(&query) + o.orm.alias.dbBaser.ReplaceMarks(&query) - args := getFlatParams(nil, o.args, o.orm.alias.TZ) + args := getFlatParams(nil, o.args, o.orm.alias.tz) var rs *sql.Rows rs, err := o.orm.db.Query(query, args...) @@ -687,6 +733,10 @@ func (o *rawSet) readValues(container interface{}, needCols []string) (int64, er cnt++ } + if err = rs.Err(); err != nil { + return 0, err + } + switch v := container.(type) { case *[]Params: *v = maps @@ -721,9 +771,9 @@ func (o *rawSet) queryRowsTo(container interface{}, keyCol, valueCol string) (in } query := o.query - o.orm.alias.DbBaser.ReplaceMarks(&query) + o.orm.alias.dbBaser.ReplaceMarks(&query) - args := getFlatParams(nil, o.args, o.orm.alias.TZ) + args := getFlatParams(nil, o.args, o.orm.alias.tz) rs, err := o.orm.db.Query(query, args...) if err != nil { @@ -794,7 +844,7 @@ func (o *rawSet) queryRowsTo(container interface{}, keyCol, valueCol string) (in } default: - if id := ind.FieldByName(camelString(key)); id.IsValid() { + if id := ind.FieldByName(models.CamelString(key)); id.IsValid() { o.setFieldValue(id, reflect.ValueOf(refs[valueIndex]).Elem().Interface()) } } @@ -802,6 +852,9 @@ func (o *rawSet) queryRowsTo(container interface{}, keyCol, valueCol string) (in cnt++ } + if err = rs.Err(); err != nil { + return 0, err + } if typ == 1 { v, _ := container.(*Params) *v = maps @@ -825,30 +878,32 @@ func (o *rawSet) ValuesFlat(container *ParamsList, cols ...string) (int64, error return o.readValues(container, cols) } -// query all rows into map[string]interface with specify key and value column name. +// query All rows into map[string]interface with specify key and value column name. // keyCol = "name", valueCol = "value" // table data // name | value // total | 100 // found | 200 -// to map[string]interface{}{ -// "total": 100, -// "found": 200, -// } +// +// to map[string]interface{}{ +// "total": 100, +// "found": 200, +// } func (o *rawSet) RowsToMap(result *Params, keyCol, valueCol string) (int64, error) { return o.queryRowsTo(result, keyCol, valueCol) } -// query all rows into struct with specify key and value column name. +// query All rows into struct with specify key and value column name. // keyCol = "name", valueCol = "value" // table data // name | value // total | 100 // found | 200 -// to struct { -// Total int -// Found int -// } +// +// to struct { +// Total int +// Found int +// } func (o *rawSet) RowsToStruct(ptrStruct interface{}, keyCol, valueCol string) (int64, error) { return o.queryRowsTo(ptrStruct, keyCol, valueCol) } @@ -858,7 +913,7 @@ func (o *rawSet) Prepare() (RawPreparer, error) { return newRawPreparer(o) } -func newRawSet(orm *orm, query string, args []interface{}) RawSeter { +func newRawSet(orm *ormBase, query string, args []interface{}) RawSeter { o := new(rawSet) o.query = query o.args = args diff --git a/orm/orm_test.go b/client/orm/orm_test.go similarity index 73% rename from orm/orm_test.go rename to client/orm/orm_test.go index bdb430b677..ade877994f 100644 --- a/orm/orm_test.go +++ b/client/orm/orm_test.go @@ -12,8 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -// +build go1.8 - package orm import ( @@ -21,7 +19,6 @@ import ( "context" "database/sql" "fmt" - "io/ioutil" "math" "os" "path/filepath" @@ -30,19 +27,30 @@ import ( "strings" "testing" "time" + + _ "github.com/mattn/go-sqlite3" + + "github.com/beego/beego/v2/client/orm/internal/utils" + + "github.com/beego/beego/v2/client/orm/internal/models" + + "github.com/stretchr/testify/assert" + + "github.com/beego/beego/v2/client/orm/clauses/order_clause" + "github.com/beego/beego/v2/client/orm/hints" ) var _ = os.PathSeparator var ( - testDate = formatDate + " -0700" - testDateTime = formatDateTime + " -0700" - testTime = formatTime + " -0700" + testDate = utils.FormatDate + " -0700" + testDateTime = utils.FormatDateTime + " -0700" + testTime = utils.FormatTime + " -0700" ) type argAny []interface{} -// get interface by index from interface slice +// Get interface by index from interface slice func (a argAny) Get(i int, args ...interface{}) (r interface{}) { if i >= 0 && i < len(a) { r = a[i] @@ -66,7 +74,7 @@ func ValuesCompare(is bool, a interface{}, args ...interface{}) (ok bool, err er case time.Time: if v2, vo := b.(time.Time); vo { if arg.Get(1) != nil { - format := ToStr(arg.Get(1)) + format := utils.ToStr(arg.Get(1)) a = v.Format(format) b = v2.Format(format) ok = a == b @@ -76,15 +84,11 @@ func ValuesCompare(is bool, a interface{}, args ...interface{}) (ok bool, err er } } default: - ok = ToStr(a) == ToStr(b) + ok = utils.ToStr(a) == utils.ToStr(b) } ok = is && ok || !is && !ok if !ok { - if is { - err = fmt.Errorf("expected: `%v`, get `%v`", b, a) - } else { - err = fmt.Errorf("expected: `%v`, get `%v`", b, a) - } + err = fmt.Errorf("expected: `%v`, Get `%v`", b, a) } wrongArg: @@ -113,7 +117,7 @@ func getCaller(skip int) string { pc, file, line, _ := runtime.Caller(skip) fun := runtime.FuncForPC(pc) _, fn := filepath.Split(file) - data, err := ioutil.ReadFile(file) + data, err := os.ReadFile(file) var codes []string if err == nil { lines := bytes.Split(data, []byte{'\n'}) @@ -128,7 +132,8 @@ func getCaller(skip int) string { if cur == line { flag = ">>" } - code := fmt.Sprintf(" %s %5d: %s", flag, cur, strings.Replace(string(lines[o+i]), "\t", " ", -1)) + ls := formatLines(string(lines[o+i])) + code := fmt.Sprintf(" %s %5d: %s", flag, cur, ls) if code != "" { codes = append(codes, code) } @@ -141,6 +146,11 @@ func getCaller(skip int) string { return fmt.Sprintf("%s:%s:%d: \n%s", fn, funName, line, strings.Join(codes, "\n")) } +func formatLines(s string) string { + return strings.ReplaceAll(s, "\t", " ") +} + +// Deprecated: Using stretchr/testify/assert func throwFail(t *testing.T, err error, args ...interface{}) { if err != nil { con := fmt.Sprintf("\t\nError: %s\n%s\n", err.Error(), getCaller(2)) @@ -156,6 +166,7 @@ func throwFail(t *testing.T, err error, args ...interface{}) { } } +// deprecated using assert.XXX func throwFailNow(t *testing.T, err error, args ...interface{}) { if err != nil { con := fmt.Sprintf("\t\nError: %s\n%s\n", err.Error(), getCaller(2)) @@ -175,16 +186,17 @@ func TestGetDB(t *testing.T) { if db, err := GetDB(); err != nil { throwFailNow(t, err) } else { - err = db.Ping() + err = db.DB.Ping() throwFailNow(t, err) } } -func TestSyncDb(t *testing.T) { +func registerAllModel() { RegisterModel(new(Data), new(DataNull), new(DataCustom)) RegisterModel(new(User)) RegisterModel(new(Profile)) RegisterModel(new(Post)) + RegisterModel(new(NullValue)) RegisterModel(new(Tag)) RegisterModel(new(Comment)) RegisterModel(new(UserBig)) @@ -197,48 +209,42 @@ func TestSyncDb(t *testing.T) { RegisterModel(new(IntegerPk)) RegisterModel(new(UintPk)) RegisterModel(new(PtrPk)) + RegisterModel(new(Index)) + RegisterModel(new(StrPk)) + RegisterModel(new(TM)) + RegisterModel(new(DeptInfo)) + RegisterModel(new(testTab), new(testTab1), new(testTab2)) +} + +func TestSyncDb(t *testing.T) { + models.DefaultModelCache.Clean() + registerAllModel() err := RunSyncdb("default", true, Debug) throwFail(t, err) - - modelCache.clean() } -func TestRegisterModels(t *testing.T) { - RegisterModel(new(Data), new(DataNull), new(DataCustom)) - RegisterModel(new(User)) - RegisterModel(new(Profile)) - RegisterModel(new(Post)) - RegisterModel(new(Tag)) - RegisterModel(new(Comment)) - RegisterModel(new(UserBig)) - RegisterModel(new(PostTags)) - RegisterModel(new(Group)) - RegisterModel(new(Permission)) - RegisterModel(new(GroupPermissions)) - RegisterModel(new(InLine)) - RegisterModel(new(InLineOneToOne)) - RegisterModel(new(IntegerPk)) - RegisterModel(new(UintPk)) - RegisterModel(new(PtrPk)) +func TestRegisterModels(_ *testing.T) { + models.DefaultModelCache.Clean() + registerAllModel() BootStrap() dORM = NewOrm() - dDbBaser = getDbAlias("default").DbBaser + dDbBaser = getDB("default").dbBaser } func TestModelSyntax(t *testing.T) { user := &User{} ind := reflect.ValueOf(user).Elem() - fn := getFullName(ind.Type()) - mi, ok := modelCache.getByFullName(fn) + fn := models.GetFullName(ind.Type()) + _, ok := models.DefaultModelCache.GetByFullName(fn) throwFail(t, AssertIs(ok, true)) - mi, ok = modelCache.get("user") + mi, ok := models.DefaultModelCache.Get("user") throwFail(t, AssertIs(ok, true)) if ok { - throwFail(t, AssertIs(mi.fields.GetByName("ShouldSkip") == nil, true)) + throwFail(t, AssertIs(mi.Fields.GetByName("ShouldSkip") == nil, true)) } } @@ -262,7 +268,7 @@ var DataValues = map[string]interface{}{ "Uint8": uint8(1<<8 - 1), "Uint16": uint16(1<<16 - 1), "Uint32": uint32(1<<32 - 1), - "Uint64": uint64(1<<63 - 1), // uint64 values with high bit set are not supported + "Uint64": uint64(1<<63 - 1), // uint64 values with high bit Set are not supported "Float32": float32(100.1234), "Float64": float64(100.1234), "Decimal": float64(100.1234), @@ -294,16 +300,97 @@ func TestDataTypes(t *testing.T) { vu := e.Interface() switch name { case "Date": - vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDate) - value = value.(time.Time).In(DefaultTimeLoc).Format(testDate) case "DateTime": - vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDateTime) - value = value.(time.Time).In(DefaultTimeLoc).Format(testDateTime) case "Time": - vu = vu.(time.Time).In(DefaultTimeLoc).Format(testTime) - value = value.(time.Time).In(DefaultTimeLoc).Format(testTime) + assert.True(t, vu.(time.Time).In(DefaultTimeLoc).Sub(value.(time.Time).In(DefaultTimeLoc)) <= time.Second) + default: + assert.Equal(t, value, vu) } - throwFail(t, AssertIs(vu == value, true), value, vu) + } +} + +func TestTM(t *testing.T) { + // The precision of sqlite is not implemented + if dORM.Driver().Type() == 2 { + return + } + var recTM TM + tm := NewTM() + tm.TMPrecision1 = time.Unix(1596766024, 123456789) + tm.TMPrecision2 = time.Unix(1596766024, 123456789) + _, err := dORM.Insert(tm) + throwFail(t, err) + + err = dORM.QueryTable("tm").One(&recTM) + throwFail(t, err) + throwFail(t, AssertIs(recTM.TMPrecision1.String(), "2020-08-07 02:07:04.123 +0000 UTC")) + throwFail(t, AssertIs(recTM.TMPrecision2.String(), "2020-08-07 02:07:04.1235 +0000 UTC")) +} + +func TestUnregisterModel(t *testing.T) { + data := []*DeptInfo{ + { + DeptName: "A", + EmployeeName: "A1", + Salary: 1000, + }, + { + DeptName: "A", + EmployeeName: "A2", + Salary: 2000, + }, + { + DeptName: "B", + EmployeeName: "B1", + Salary: 2000, + }, + { + DeptName: "B", + EmployeeName: "B2", + Salary: 4000, + }, + { + DeptName: "B", + EmployeeName: "B3", + Salary: 3000, + }, + } + qs := dORM.QueryTable("dept_info") + i, _ := qs.PrepareInsert() + for _, d := range data { + _, err := i.Insert(d) + if err != nil { + throwFail(t, err) + } + } + + f := func() { + var res []UnregisterModel + n, err := dORM.QueryTable("dept_info").All(&res) + throwFail(t, err) + throwFail(t, AssertIs(n, 5)) + throwFail(t, AssertIs(res[0].EmployeeName, "A1")) + + type Sum struct { + DeptName string + Total int + } + var sun []Sum + qs.Aggregate("dept_name,sum(salary) as total").GroupBy("dept_name").OrderBy("dept_name").All(&sun) + throwFail(t, AssertIs(sun[0].DeptName, "A")) + throwFail(t, AssertIs(sun[0].Total, 3000)) + + type Max struct { + DeptName string + Max float64 + } + var max []Max + qs.Aggregate("dept_name,max(salary) as max").GroupBy("dept_name").OrderBy("dept_name").All(&max) + throwFail(t, AssertIs(max[1].DeptName, "B")) + throwFail(t, AssertIs(max[1].Max, 4000)) + } + for i := 0; i < 5; i++ { + f() } } @@ -455,11 +542,13 @@ func TestNullDataTypes(t *testing.T) { throwFail(t, AssertIs(*d.Float32Ptr, float32Ptr)) throwFail(t, AssertIs(*d.Float64Ptr, float64Ptr)) throwFail(t, AssertIs(*d.DecimalPtr, decimalPtr)) - throwFail(t, AssertIs((*d.TimePtr).UTC().Format(testTime), timePtr.UTC().Format(testTime))) - throwFail(t, AssertIs((*d.DatePtr).UTC().Format(testDate), datePtr.UTC().Format(testDate))) - throwFail(t, AssertIs((*d.DateTimePtr).UTC().Format(testDateTime), dateTimePtr.UTC().Format(testDateTime))) - // test support for pointer fields using RawSeter.QueryRows() + // in mysql, there are some precision problem, (*d.TimePtr).UTC() != timePtr.UTC() + assert.True(t, (*d.TimePtr).UTC().Sub(timePtr.UTC()) <= time.Second) + assert.True(t, (*d.DatePtr).UTC().Sub(datePtr.UTC()) <= time.Second) + assert.True(t, (*d.DateTimePtr).UTC().Sub(dateTimePtr.UTC()) <= time.Second) + + // test support for pointer Fields using RawSeter.QueryRows() var dnList []*DataNull Q := dDbBaser.TableQuote() num, err = dORM.Raw(fmt.Sprintf("SELECT * FROM %sdata_null%s where id=?", Q, Q), 3).QueryRows(&dnList) @@ -532,8 +621,9 @@ func TestCRUD(t *testing.T) { throwFail(t, AssertIs(u.Status, 3)) throwFail(t, AssertIs(u.IsStaff, true)) throwFail(t, AssertIs(u.IsActive, true)) - throwFail(t, AssertIs(u.Created.In(DefaultTimeLoc), user.Created.In(DefaultTimeLoc), testDate)) - throwFail(t, AssertIs(u.Updated.In(DefaultTimeLoc), user.Updated.In(DefaultTimeLoc), testDateTime)) + + assert.True(t, u.Created.In(DefaultTimeLoc).Sub(user.Created.In(DefaultTimeLoc)) <= time.Second) + assert.True(t, u.Updated.In(DefaultTimeLoc).Sub(user.Updated.In(DefaultTimeLoc)) <= time.Second) user.UserName = "astaxie" user.Profile = profile @@ -669,7 +759,7 @@ func TestInsertTestData(t *testing.T) { posts := []*Post{ {User: users[0], Tags: []*Tag{tags[0]}, Title: "Introduction", Content: `Go is a new language. Although it borrows ideas from existing languages, it has unusual properties that make effective Go programs different in character from programs written in its relatives. A straightforward translation of a C++ or Java program into Go is unlikely to produce a satisfactory result—Java programs are written in Java, not Go. On the other hand, thinking about the problem from a Go perspective could produce a successful but quite different program. In other words, to write Go well, it's important to understand its properties and idioms. It's also important to know the established conventions for programming in Go, such as naming, formatting, program construction, and so on, so that programs you write will be easy for other Go programmers to understand. -This document gives tips for writing clear, idiomatic Go code. It augments the language specification, the Tour of Go, and How to Write Go Code, all of which you should read first.`}, +This document gives tips for writing clear, idiomatic Go code. It augments the language specification, the Tour of Go, and How to Write Go Code, All of which you should read first.`}, {User: users[1], Tags: []*Tag{tags[0], tags[1]}, Title: "Examples", Content: `The Go package sources are intended to serve not only as the core library but also as examples of how to use the language. Moreover, many of the packages contain working, self-contained executable examples you can run directly from the golang.org web site, such as this one (click on the word "Example" to open it up). If you have a question about how to approach a problem or how something might be implemented, the documentation, code and examples in the library can provide answers, ideas and background.`}, {User: users[1], Tags: []*Tag{tags[0], tags[2]}, Title: "Formatting", Content: `Formatting issues are the most contentious but the least consequential. People can adapt to different formatting styles but it's better if they don't have to, and less time is devoted to the topic if everyone adheres to the same style. The problem is how to approach this Utopia without a long prescriptive style guide. With Go we take an unusual approach and let the machine take care of most formatting issues. The gofmt program (also available as go fmt, which operates at the package level rather than source file level) reads a Go program and emits the source in a standard style of indentation and vertical alignment, retaining and if necessary reformatting comments. If you want to know how to handle some new layout situation, run gofmt; if the answer doesn't seem right, rearrange your program (or file a bug about gofmt), don't work around it.`}, @@ -746,7 +836,6 @@ The program—and web server—godoc processes Go source files to extract docume throwFailNow(t, AssertIs(nums, num)) } } - } func TestCustomField(t *testing.T) { @@ -769,14 +858,29 @@ func TestCustomField(t *testing.T) { throwFailNow(t, AssertIs(user.Extra.Name, "beego")) throwFailNow(t, AssertIs(user.Extra.Data, "orm")) + + var users []User + Q := dDbBaser.TableQuote() + n, err := dORM.Raw(fmt.Sprintf("SELECT * FROM %suser%s where id=?", Q, Q), 2).QueryRows(&users) + throwFailNow(t, err) + throwFailNow(t, AssertIs(n, 1)) + throwFailNow(t, AssertIs(users[0].Extra.Name, "beego")) + throwFailNow(t, AssertIs(users[0].Extra.Data, "orm")) + + user = User{} + err = dORM.Raw(fmt.Sprintf("SELECT * FROM %suser%s where id=?", Q, Q), 2).QueryRow(&user) + throwFailNow(t, err) + throwFailNow(t, AssertIs(user.Extra.Name, "beego")) + throwFailNow(t, AssertIs(user.Extra.Data, "orm")) } func TestExpr(t *testing.T) { user := &User{} - qs := dORM.QueryTable(user) - qs = dORM.QueryTable((*User)(nil)) - qs = dORM.QueryTable("User") - qs = dORM.QueryTable("user") + var qs QuerySeter + assert.NotPanics(t, func() { qs = dORM.QueryTable(user) }) + assert.NotPanics(t, func() { qs = dORM.QueryTable((*User)(nil)) }) + assert.NotPanics(t, func() { qs = dORM.QueryTable("User") }) + assert.NotPanics(t, func() { qs = dORM.QueryTable("user") }) num, err := qs.Filter("UserName", "slene").Filter("user_name", "slene").Filter("profile__Age", 28).Count() throwFail(t, err) throwFail(t, AssertIs(num, 1)) @@ -790,6 +894,32 @@ func TestExpr(t *testing.T) { // throwFail(t, AssertIs(num, 3)) } +func TestSpecifyIndex(t *testing.T) { + var index *Index + index = &Index{ + F1: 1, + F2: 2, + } + _, _ = dORM.Insert(index) + throwFailNow(t, AssertIs(index.Id, 1)) + + index = &Index{ + F1: 3, + F2: 4, + } + _, _ = dORM.Insert(index) + throwFailNow(t, AssertIs(index.Id, 2)) + + _ = dORM.QueryTable(&Index{}).Filter(`f1`, `1`).ForceIndex(`index_f1`).One(index) + throwFailNow(t, AssertIs(index.F2, 2)) + + _ = dORM.QueryTable(&Index{}).Filter(`f2`, `4`).UseIndex(`index_f2`).One(index) + throwFailNow(t, AssertIs(index.F1, 3)) + + _ = dORM.QueryTable(&Index{}).Filter(`f1`, `1`).IgnoreIndex(`index_f1`, `index_f2`).One(index) + throwFailNow(t, AssertIs(index.F2, 2)) +} + func TestOperators(t *testing.T) { qs := dORM.QueryTable("user") num, err := qs.Filter("user_name", "slene").Count() @@ -808,6 +938,17 @@ func TestOperators(t *testing.T) { throwFail(t, err) throwFail(t, AssertIs(num, 1)) + if IsMysql { + // Now only mysql support `strictexact` + num, err = qs.Filter("user_name__strictexact", "Slene").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 0)) + + num, err = qs.Filter("user_name__strictexact", "slene").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + } + num, err = qs.Filter("user_name__contains", "e").Count() throwFail(t, err) throwFail(t, AssertIs(num, 2)) @@ -984,7 +1125,10 @@ func TestOffset(t *testing.T) { throwFail(t, AssertIs(num, 2)) } -func TestOrderBy(t *testing.T) { +func TestCountOrderBy(t *testing.T) { + if IsPostgres { + return + } qs := dORM.QueryTable("user") num, err := qs.OrderBy("-status").Filter("user_name", "nobody").Count() throwFail(t, err) @@ -997,6 +1141,81 @@ func TestOrderBy(t *testing.T) { num, err = qs.OrderBy("-profile__age").Filter("user_name", "astaxie").Count() throwFail(t, err) throwFail(t, AssertIs(num, 1)) + + num, err = qs.OrderClauses( + order_clause.Clause( + order_clause.Column(`profile__age`), + order_clause.SortDescending(), + ), + ).Filter("user_name", "astaxie").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + if IsMysql { + num, err = qs.OrderClauses( + order_clause.Clause( + order_clause.Column(`rand()`), + order_clause.Raw(), + ), + ).Filter("user_name", "astaxie").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + } +} + +func TestOrderBy(t *testing.T) { + var users []*User + qs := dORM.QueryTable("user") + num, err := qs.OrderBy("-status").Filter("user_name", "nobody").All(&users) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + num, err = qs.OrderBy("status").Filter("user_name", "slene").All(&users) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + num, err = qs.OrderBy("-profile__age").Filter("user_name", "astaxie").All(&users) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + num, err = qs.OrderClauses( + order_clause.Clause( + order_clause.Column(`profile__age`), + order_clause.SortDescending(), + ), + ).Filter("user_name", "astaxie").All(&users) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + if IsMysql { + num, err = qs.OrderClauses( + order_clause.Clause( + order_clause.Column(`rand()`), + order_clause.Raw(), + ), + ).Filter("user_name", "astaxie").All(&users) + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + } +} + +func TestCount(t *testing.T) { + qs := dORM.QueryTable("user") + num, err := qs.Filter("user_name", "nobody").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + num, err = qs.Filter("user_name", "slene").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + num, err = qs.Filter("user_name", "astaxie").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) + + num, err = qs.Filter("user_name", "astaxie").Count() + throwFail(t, err) + throwFail(t, AssertIs(num, 1)) } func TestAll(t *testing.T) { @@ -1068,7 +1287,6 @@ func TestOne(t *testing.T) { err = qs.Filter("user_name", "nothing").One(&user) throwFail(t, AssertIs(err, ErrNoRows)) - } func TestValues(t *testing.T) { @@ -1083,6 +1301,19 @@ func TestValues(t *testing.T) { throwFail(t, AssertIs(maps[2]["Profile"], nil)) } + num, err = qs.OrderClauses( + order_clause.Clause( + order_clause.Column("Id"), + order_clause.SortAscending(), + ), + ).Values(&maps) + throwFail(t, err) + throwFail(t, AssertIs(num, 3)) + if num == 3 { + throwFail(t, AssertIs(maps[0]["UserName"], "slene")) + throwFail(t, AssertIs(maps[2]["Profile"], nil)) + } + num, err = qs.OrderBy("Id").Values(&maps, "UserName", "Profile__Age") throwFail(t, err) throwFail(t, AssertIs(num, 3)) @@ -1105,8 +1336,8 @@ func TestValuesList(t *testing.T) { throwFail(t, err) throwFail(t, AssertIs(num, 3)) if num == 3 { - throwFail(t, AssertIs(list[0][1], "slene")) - throwFail(t, AssertIs(list[2][9], nil)) + throwFail(t, AssertIs(list[0][1], "slene")) // username + throwFail(t, AssertIs(list[2][10], nil)) // profile } num, err = qs.OrderBy("Id").ValuesList(&list, "UserName", "Profile__Age") @@ -1276,24 +1507,32 @@ func TestLoadRelated(t *testing.T) { throwFailNow(t, AssertIs(len(user.Posts), 2)) throwFailNow(t, AssertIs(user.Posts[0].User.ID, 3)) - num, err = dORM.LoadRelated(&user, "Posts", true) + num, err = dORM.LoadRelated(&user, "Posts", hints.DefaultRelDepth()) throwFailNow(t, err) throwFailNow(t, AssertIs(num, 2)) throwFailNow(t, AssertIs(len(user.Posts), 2)) throwFailNow(t, AssertIs(user.Posts[0].User.UserName, "astaxie")) - num, err = dORM.LoadRelated(&user, "Posts", true, 1) + num, err = dORM.LoadRelated(&user, "Posts", + hints.DefaultRelDepth(), + hints.Limit(1)) throwFailNow(t, err) throwFailNow(t, AssertIs(num, 1)) throwFailNow(t, AssertIs(len(user.Posts), 1)) - num, err = dORM.LoadRelated(&user, "Posts", true, 0, 0, "-Id") + num, err = dORM.LoadRelated(&user, "Posts", + hints.DefaultRelDepth(), + hints.OrderBy("-Id")) throwFailNow(t, err) throwFailNow(t, AssertIs(num, 2)) throwFailNow(t, AssertIs(len(user.Posts), 2)) throwFailNow(t, AssertIs(user.Posts[0].Title, "Formatting")) - num, err = dORM.LoadRelated(&user, "Posts", true, 1, 1, "Id") + num, err = dORM.LoadRelated(&user, "Posts", + hints.DefaultRelDepth(), + hints.Limit(1), + hints.Offset(1), + hints.OrderBy("Id")) throwFailNow(t, err) throwFailNow(t, AssertIs(num, 1)) throwFailNow(t, AssertIs(len(user.Posts), 1)) @@ -1315,7 +1554,7 @@ func TestLoadRelated(t *testing.T) { throwFailNow(t, AssertIs(profile.User == nil, false)) throwFailNow(t, AssertIs(profile.User.UserName, "astaxie")) - num, err = dORM.LoadRelated(&profile, "User", true) + num, err = dORM.LoadRelated(&profile, "User", hints.DefaultRelDepth()) throwFailNow(t, err) throwFailNow(t, AssertIs(num, 1)) throwFailNow(t, AssertIs(profile.User == nil, false)) @@ -1332,7 +1571,7 @@ func TestLoadRelated(t *testing.T) { throwFailNow(t, AssertIs(user.Profile == nil, false)) throwFailNow(t, AssertIs(user.Profile.Age, 30)) - num, err = dORM.LoadRelated(&user, "Profile", true) + num, err = dORM.LoadRelated(&user, "Profile", hints.DefaultRelDepth()) throwFailNow(t, err) throwFailNow(t, AssertIs(num, 1)) throwFailNow(t, AssertIs(user.Profile == nil, false)) @@ -1352,7 +1591,7 @@ func TestLoadRelated(t *testing.T) { throwFailNow(t, AssertIs(post.User == nil, false)) throwFailNow(t, AssertIs(post.User.UserName, "astaxie")) - num, err = dORM.LoadRelated(&post, "User", true) + num, err = dORM.LoadRelated(&post, "User", hints.DefaultRelDepth()) throwFailNow(t, err) throwFailNow(t, AssertIs(num, 1)) throwFailNow(t, AssertIs(post.User == nil, false)) @@ -1372,7 +1611,7 @@ func TestLoadRelated(t *testing.T) { throwFailNow(t, AssertIs(len(post.Tags), 2)) throwFailNow(t, AssertIs(post.Tags[0].Name, "golang")) - num, err = dORM.LoadRelated(&post, "Tags", true) + num, err = dORM.LoadRelated(&post, "Tags", hints.DefaultRelDepth()) throwFailNow(t, err) throwFailNow(t, AssertIs(num, 2)) throwFailNow(t, AssertIs(len(post.Tags), 2)) @@ -1393,7 +1632,7 @@ func TestLoadRelated(t *testing.T) { throwFailNow(t, AssertIs(tag.Posts[0].User.ID, 2)) throwFailNow(t, AssertIs(tag.Posts[0].User.Profile == nil, true)) - num, err = dORM.LoadRelated(&tag, "Posts", true) + num, err = dORM.LoadRelated(&tag, "Posts", hints.DefaultRelDepth()) throwFailNow(t, err) throwFailNow(t, AssertIs(num, 3)) throwFailNow(t, AssertIs(tag.Posts[0].Title, "Introduction")) @@ -1532,7 +1771,7 @@ func TestQueryM2M(t *testing.T) { throwFailNow(t, AssertIs(num, 1)) } -func TestQueryRelate(t *testing.T) { +func TestQueryRelate(_ *testing.T) { // post := &Post{Id: 2} // qs := dORM.QueryRelate(post, "Tags") @@ -1656,18 +1895,11 @@ func TestRawQueryRow(t *testing.T) { switch col { case "id": throwFail(t, AssertIs(id, 1)) - case "time": + case "time", "datetime": v = v.(time.Time).In(DefaultTimeLoc) value := dataValues[col].(time.Time).In(DefaultTimeLoc) - throwFail(t, AssertIs(v, value, testTime)) + assert.True(t, v.(time.Time).Sub(value) <= time.Second) case "date": - v = v.(time.Time).In(DefaultTimeLoc) - value := dataValues[col].(time.Time).In(DefaultTimeLoc) - throwFail(t, AssertIs(v, value, testDate)) - case "datetime": - v = v.(time.Time).In(DefaultTimeLoc) - value := dataValues[col].(time.Time).In(DefaultTimeLoc) - throwFail(t, AssertIs(v, value, testDateTime)) default: throwFail(t, AssertIs(v, dataValues[col])) } @@ -1689,7 +1921,25 @@ func TestRawQueryRow(t *testing.T) { throwFail(t, AssertIs(*status, 3)) throwFail(t, AssertIs(pid, nil)) - // test for sql.Null* fields + type Embedded struct { + Email string + } + type queryRowNoModelTest struct { + Id int + EmbedField Embedded + } + + cols = []string{ + "id", "email", + } + var row queryRowNoModelTest + query = fmt.Sprintf("SELECT %s%s%s FROM %suser%s WHERE id = ?", Q, strings.Join(cols, sep), Q, Q, Q) + err = dORM.Raw(query, 4).QueryRow(&row) + throwFail(t, err) + throwFail(t, AssertIs(row.Id, 4)) + throwFail(t, AssertIs(row.EmbedField.Email, "nobody@gmail.com")) + + // test for sql.Null* Fields nData := &DataNull{ NullString: sql.NullString{String: "test sql.null", Valid: true}, NullBool: sql.NullBool{Bool: true, Valid: true}, @@ -1740,16 +1990,12 @@ func TestQueryRows(t *testing.T) { vu := e.Interface() switch name { case "Time": - vu = vu.(time.Time).In(DefaultTimeLoc).Format(testTime) - value = value.(time.Time).In(DefaultTimeLoc).Format(testTime) case "Date": - vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDate) - value = value.(time.Time).In(DefaultTimeLoc).Format(testDate) case "DateTime": - vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDateTime) - value = value.(time.Time).In(DefaultTimeLoc).Format(testDateTime) + assert.True(t, vu.(time.Time).In(DefaultTimeLoc).Sub(value.(time.Time).In(DefaultTimeLoc)) <= time.Second) + default: + assert.Equal(t, value, vu) } - throwFail(t, AssertIs(vu == value, true), value, vu) } var datas2 []Data @@ -1767,16 +2013,13 @@ func TestQueryRows(t *testing.T) { vu := e.Interface() switch name { case "Time": - vu = vu.(time.Time).In(DefaultTimeLoc).Format(testTime) - value = value.(time.Time).In(DefaultTimeLoc).Format(testTime) case "Date": - vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDate) - value = value.(time.Time).In(DefaultTimeLoc).Format(testDate) case "DateTime": - vu = vu.(time.Time).In(DefaultTimeLoc).Format(testDateTime) - value = value.(time.Time).In(DefaultTimeLoc).Format(testDateTime) + assert.True(t, vu.(time.Time).In(DefaultTimeLoc).Sub(value.(time.Time).In(DefaultTimeLoc)) <= time.Second) + default: + assert.Equal(t, value, vu) } - throwFail(t, AssertIs(vu == value, true), value, vu) + } var ids []int @@ -1793,7 +2036,7 @@ func TestQueryRows(t *testing.T) { throwFailNow(t, AssertIs(ids[2], 4)) throwFailNow(t, AssertIs(usernames[2], "nobody")) - //test query rows by nested struct + // test query rows by nested struct var l []userProfile query = fmt.Sprintf("SELECT * FROM %suser_profile%s LEFT JOIN %suser%s ON %suser_profile%s.%sid%s = %suser%s.%sid%s", Q, Q, Q, Q, Q, Q, Q, Q, Q, Q, Q, Q) num, err = dORM.Raw(query).QueryRows(&l) @@ -1805,7 +2048,7 @@ func TestQueryRows(t *testing.T) { throwFailNow(t, AssertIs(l[1].UserName, "astaxie")) throwFailNow(t, AssertIs(l[1].Age, 30)) - // test for sql.Null* fields + // test for sql.Null* Fields nData := &DataNull{ NullString: sql.NullString{String: "test sql.null", Valid: true}, NullBool: sql.NullBool{Bool: true, Valid: true}, @@ -1865,69 +2108,77 @@ func TestRawValues(t *testing.T) { } } -func TestRawPrepare(t *testing.T) { - switch { - case IsMysql || IsSqlite: +func TestForIssue4709(t *testing.T) { + pre, err := dORM.Raw("INSERT into null_value (value) VALUES (?)").Prepare() + assert.Nil(t, err) + _, err = pre.Exec(nil) + assert.Nil(t, err) +} - pre, err := dORM.Raw("INSERT INTO tag (name) VALUES (?)").Prepare() - throwFail(t, err) +func TestRawPrepare(t *testing.T) { + var ( + result sql.Result + err error + pre RawPreparer + ) + if IsMysql || IsSqlite { + pre, err = dORM.Raw("INSERT INTO tag (name) VALUES (?)").Prepare() + assert.Nil(t, err) if pre != nil { - r, err := pre.Exec("name1") - throwFail(t, err) + result, err = pre.Exec("name1") + assert.Nil(t, err) - tid, err := r.LastInsertId() - throwFail(t, err) - throwFail(t, AssertIs(tid > 0, true)) + tid, err := result.LastInsertId() + assert.Nil(t, err) + assert.True(t, tid > 0) - r, err = pre.Exec("name2") - throwFail(t, err) + result, err = pre.Exec("name2") + assert.Nil(t, err) - id, err := r.LastInsertId() - throwFail(t, err) - throwFail(t, AssertIs(id, tid+1)) + id, err := result.LastInsertId() + assert.Nil(t, err) + assert.Equal(t, id, tid+1) - r, err = pre.Exec("name3") - throwFail(t, err) + result, err = pre.Exec("name3") + assert.Nil(t, err) - id, err = r.LastInsertId() - throwFail(t, err) - throwFail(t, AssertIs(id, tid+2)) + id, err = result.LastInsertId() + assert.Nil(t, err) + assert.Equal(t, id, tid+2) err = pre.Close() - throwFail(t, err) + assert.Nil(t, err) res, err := dORM.Raw("DELETE FROM tag WHERE name IN (?, ?, ?)", []string{"name1", "name2", "name3"}).Exec() - throwFail(t, err) + assert.Nil(t, err) num, err := res.RowsAffected() - throwFail(t, err) - throwFail(t, AssertIs(num, 3)) + assert.Nil(t, err) + assert.Equal(t, num, int64(3)) } - - case IsPostgres: - - pre, err := dORM.Raw(`INSERT INTO "tag" ("name") VALUES (?) RETURNING "id"`).Prepare() - throwFail(t, err) + } else if IsPostgres { + pre, err = dORM.Raw(`INSERT INTO "tag" ("name") VALUES (?) RETURNING "id"`).Prepare() + assert.Nil(t, err) if pre != nil { - _, err := pre.Exec("name1") - throwFail(t, err) + _, err = pre.Exec("name1") + assert.Nil(t, err) _, err = pre.Exec("name2") - throwFail(t, err) + assert.Nil(t, err) _, err = pre.Exec("name3") - throwFail(t, err) + assert.Nil(t, err) err = pre.Close() - throwFail(t, err) + assert.Nil(t, err) res, err := dORM.Raw(`DELETE FROM "tag" WHERE "name" IN (?, ?, ?)`, []string{"name1", "name2", "name3"}).Exec() - throwFail(t, err) + assert.Nil(t, err) if err == nil { num, err := res.RowsAffected() - throwFail(t, err) - throwFail(t, AssertIs(num, 3)) + assert.Nil(t, err) + assert.Equal(t, num, int64(3)) } } } @@ -2020,24 +2271,23 @@ func TestTransaction(t *testing.T) { // this test worked when database support transaction o := NewOrm() - err := o.Begin() + to, err := o.Begin() throwFail(t, err) - var names = []string{"1", "2", "3"} + names := []string{"1", "2", "3"} var tag Tag tag.Name = names[0] - id, err := o.Insert(&tag) + id, err := to.Insert(&tag) throwFail(t, err) throwFail(t, AssertIs(id > 0, true)) - num, err := o.QueryTable("tag").Filter("name", "golang").Update(Params{"name": names[1]}) + num, err := to.QueryTable("tag").Filter("name", "golang").Update(Params{"name": names[1]}) throwFail(t, err) throwFail(t, AssertIs(num, 1)) - switch { - case IsMysql || IsSqlite: - res, err := o.Raw("INSERT INTO tag (name) VALUES (?)", names[2]).Exec() + if IsMysql || IsSqlite { + res, err := to.Raw("INSERT INTO tag (name) VALUES (?)", names[2]).Exec() throwFail(t, err) if err == nil { id, err = res.LastInsertId() @@ -2046,28 +2296,63 @@ func TestTransaction(t *testing.T) { } } - err = o.Rollback() - throwFail(t, err) - + err = to.Rollback() + assert.Nil(t, err) num, err = o.QueryTable("tag").Filter("name__in", names).Count() - throwFail(t, err) - throwFail(t, AssertIs(num, 0)) + assert.Nil(t, err) + assert.Equal(t, int64(0), num) - err = o.Begin() - throwFail(t, err) + to, err = o.Begin() + assert.Nil(t, err) tag.Name = "commit" - id, err = o.Insert(&tag) - throwFail(t, err) - throwFail(t, AssertIs(id > 0, true)) + id, err = to.Insert(&tag) + assert.Nil(t, err) + assert.True(t, id > 0) - o.Commit() - throwFail(t, err) + err = to.Commit() + assert.Nil(t, err) num, err = o.QueryTable("tag").Filter("name", "commit").Delete() - throwFail(t, err) - throwFail(t, AssertIs(num, 1)) + assert.Nil(t, err) + assert.Equal(t, int64(1), num) +} + +func TestTxOrmRollbackUnlessCommit(t *testing.T) { + o := NewOrm() + var tag Tag + // test not committed and call RollbackUnlessCommit + to, err := o.Begin() + assert.Nil(t, err) + tag.Name = "rollback unless commit" + rows, err := to.Insert(&tag) + assert.Nil(t, err) + assert.True(t, rows > 0) + err = to.RollbackUnlessCommit() + assert.Nil(t, err) + num, err := o.QueryTable("tag").Filter("name", tag.Name).Delete() + assert.Nil(t, err) + assert.Equal(t, int64(0), num) + + // test commit and call RollbackUnlessCommit + + to, err = o.Begin() + assert.Nil(t, err) + tag.Name = "rollback unless commit" + rows, err = to.Insert(&tag) + assert.Nil(t, err) + assert.True(t, rows > 0) + + err = to.Commit() + assert.Nil(t, err) + + err = to.RollbackUnlessCommit() + assert.Nil(t, err) + + num, err = o.QueryTable("tag").Filter("name", tag.Name).Delete() + assert.Nil(t, err) + assert.Equal(t, int64(1), num) } func TestTransactionIsolationLevel(t *testing.T) { @@ -2080,33 +2365,33 @@ func TestTransactionIsolationLevel(t *testing.T) { o2 := NewOrm() // start two transaction with isolation level repeatable read - err := o1.BeginTx(context.Background(), &sql.TxOptions{Isolation: sql.LevelRepeatableRead}) + to1, err := o1.BeginWithCtxAndOpts(context.Background(), &sql.TxOptions{Isolation: sql.LevelRepeatableRead}) throwFail(t, err) - err = o2.BeginTx(context.Background(), &sql.TxOptions{Isolation: sql.LevelRepeatableRead}) + to2, err := o2.BeginWithCtxAndOpts(context.Background(), &sql.TxOptions{Isolation: sql.LevelRepeatableRead}) throwFail(t, err) // o1 insert tag var tag Tag tag.Name = "test-transaction" - id, err := o1.Insert(&tag) + id, err := to1.Insert(&tag) throwFail(t, err) throwFail(t, AssertIs(id > 0, true)) // o2 query tag table, no result - num, err := o2.QueryTable("tag").Filter("name", "test-transaction").Count() + num, err := to2.QueryTable("tag").Filter("name", "test-transaction").Count() throwFail(t, err) throwFail(t, AssertIs(num, 0)) // o1 commit - o1.Commit() + to1.Commit() // o2 query tag table, still no result - num, err = o2.QueryTable("tag").Filter("name", "test-transaction").Count() + num, err = to2.QueryTable("tag").Filter("name", "test-transaction").Count() throwFail(t, err) throwFail(t, AssertIs(num, 0)) - // o2 commit and query tag table, get the result - o2.Commit() + // o2 commit and query tag table, Get the result + to2.Commit() num, err = o2.QueryTable("tag").Filter("name", "test-transaction").Count() throwFail(t, err) throwFail(t, AssertIs(num, 1)) @@ -2119,14 +2404,14 @@ func TestTransactionIsolationLevel(t *testing.T) { func TestBeginTxWithContextCanceled(t *testing.T) { o := NewOrm() ctx, cancel := context.WithCancel(context.Background()) - o.BeginTx(ctx, nil) - id, err := o.Insert(&Tag{Name: "test-context"}) + to, _ := o.BeginWithCtx(ctx) + id, err := to.Insert(&Tag{Name: "test-context"}) throwFail(t, err) throwFail(t, AssertIs(id > 0, true)) // cancel the context before commit to make it error cancel() - err = o.Commit() + err = to.Commit() throwFail(t, AssertIs(err, context.Canceled)) } @@ -2165,8 +2450,10 @@ func TestReadOrCreate(t *testing.T) { throwFail(t, AssertIs(nu.Status, u.Status)) throwFail(t, AssertIs(nu.IsStaff, u.IsStaff)) throwFail(t, AssertIs(nu.IsActive, u.IsActive)) + assert.True(t, u.ID > 0) dORM.Delete(u) + assert.True(t, u.ID > 0) } func TestInLine(t *testing.T) { @@ -2187,8 +2474,8 @@ func TestInLine(t *testing.T) { throwFail(t, AssertIs(il.Name, name)) throwFail(t, AssertIs(il.Email, email)) - throwFail(t, AssertIs(il.Created.In(DefaultTimeLoc), inline.Created.In(DefaultTimeLoc), testDate)) - throwFail(t, AssertIs(il.Updated.In(DefaultTimeLoc), inline.Updated.In(DefaultTimeLoc), testDateTime)) + assert.True(t, il.Created.In(DefaultTimeLoc).Sub(inline.Created.In(DefaultTimeLoc)) <= time.Second) + assert.True(t, il.Updated.In(DefaultTimeLoc).Sub(inline.Updated.In(DefaultTimeLoc)) <= time.Second) } func TestInLineOneToOne(t *testing.T) { @@ -2374,7 +2661,7 @@ func TestSnake(t *testing.T) { "tag_666Name": "tag_666_name", } for name, want := range cases { - got := snakeString(name) + got := models.SnakeString(name) throwFail(t, AssertIs(got, want)) } } @@ -2387,108 +2674,348 @@ func TestIgnoreCaseTag(t *testing.T) { Name02 string `orm:"COLUMN(Name)"` Name03 string `orm:"Column(name)"` } - modelCache.clean() + models.DefaultModelCache.Clean() RegisterModel(&testTagModel{}) - info, ok := modelCache.get("test_tag_model") + info, ok := models.DefaultModelCache.Get("test_tag_model") throwFail(t, AssertIs(ok, true)) throwFail(t, AssertNot(info, nil)) if t == nil { return } - throwFail(t, AssertIs(info.fields.GetByName("NOO").column, "n")) - throwFail(t, AssertIs(info.fields.GetByName("Name01").null, true)) - throwFail(t, AssertIs(info.fields.GetByName("Name02").column, "Name")) - throwFail(t, AssertIs(info.fields.GetByName("Name03").column, "name")) + throwFail(t, AssertIs(info.Fields.GetByName("NOO").Column, "n")) + throwFail(t, AssertIs(info.Fields.GetByName("Name01").Null, true)) + throwFail(t, AssertIs(info.Fields.GetByName("Name02").Column, "Name")) + throwFail(t, AssertIs(info.Fields.GetByName("Name03").Column, "name")) } func TestInsertOrUpdate(t *testing.T) { RegisterModel(new(User)) - user := User{UserName: "unique_username133", Status: 1, Password: "o"} - user1 := User{UserName: "unique_username133", Status: 2, Password: "o"} - user2 := User{UserName: "unique_username133", Status: 3, Password: "oo"} + userName := "unique_username133" + column := "user_name" + user := User{UserName: userName, Status: 1, Password: "o"} + user1 := User{UserName: userName, Status: 2, Password: "o"} + user2 := User{UserName: userName, Status: 3, Password: "oo"} dORM.Insert(&user) - test := User{UserName: "unique_username133"} fmt.Println(dORM.Driver().Name()) if dORM.Driver().Name() == "sqlite3" { fmt.Println("sqlite3 is nonsupport") return } - //test1 - _, err := dORM.InsertOrUpdate(&user1, "user_name") - if err != nil { - fmt.Println(err) - if err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport InsertOrUpdate in beego" { - } else { - throwFailNow(t, err) + + specs := []struct { + description string + user User + colConflitAndArgs []string + assertion func(expected User, actual User) + isPostgresCompatible bool + }{ + { + description: "test1", + user: user1, + colConflitAndArgs: []string{column}, + assertion: func(expected, actual User) { + throwFailNow(t, AssertIs(expected.Status, actual.Status)) + }, + isPostgresCompatible: true, + }, + { + description: "test2", + user: user2, + colConflitAndArgs: []string{column}, + assertion: func(expected, actual User) { + throwFailNow(t, AssertIs(expected.Status, actual.Status)) + throwFailNow(t, AssertIs(expected.Password, strings.TrimSpace(actual.Password))) + }, + isPostgresCompatible: true, + }, + { + description: "test3 +", + user: user2, + colConflitAndArgs: []string{column, "status=status+1"}, + assertion: func(expected, actual User) { + throwFailNow(t, AssertIs(expected.Status+1, actual.Status)) + }, + isPostgresCompatible: false, + }, + { + description: "test4 -", + user: user2, + colConflitAndArgs: []string{column, "status=status-1"}, + assertion: func(expected, actual User) { + throwFailNow(t, AssertIs((expected.Status+1)-1, actual.Status)) + }, + isPostgresCompatible: false, + }, + { + description: "test5 *", + user: user2, + colConflitAndArgs: []string{column, "status=status*3"}, + assertion: func(expected, actual User) { + throwFailNow(t, AssertIs(((expected.Status+1)-1)*3, actual.Status)) + }, + isPostgresCompatible: false, + }, + { + description: "test6 /", + user: user2, + colConflitAndArgs: []string{column, "Status=Status/3"}, + assertion: func(expected, actual User) { + throwFailNow(t, AssertIs((((expected.Status+1)-1)*3)/3, actual.Status)) + }, + isPostgresCompatible: false, + }, + } + + for _, spec := range specs { + // postgres ON CONFLICT DO UPDATE SET can`t use colu=colu+values + if IsPostgres && !spec.isPostgresCompatible { + continue } - } else { - dORM.Read(&test, "user_name") - throwFailNow(t, AssertIs(user1.Status, test.Status)) + + _, err := dORM.InsertOrUpdate(&spec.user, spec.colConflitAndArgs...) + if err != nil { + fmt.Println(err) + if !(err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport InsertOrUpdate in beego") { + throwFailNow(t, err) + } + continue + } + + test := User{UserName: userName} + err = dORM.Read(&test, column) + throwFailNow(t, AssertIs(err, nil)) + spec.assertion(spec.user, test) + } +} + +func TestStrPkInsert(t *testing.T) { + RegisterModel(new(StrPk)) + pk := `1` + value := `StrPkValues(*56` + strPk := &StrPk{ + Id: pk, + Value: value, + } + + var err error + _, err = dORM.Insert(strPk) + if err != ErrLastInsertIdUnavailable { + throwFailNow(t, AssertIs(err, nil)) } - //test2 - _, err = dORM.InsertOrUpdate(&user2, "user_name") + + var vForTesting StrPk + err = dORM.QueryTable(new(StrPk)).Filter(`id`, pk).One(&vForTesting) + throwFailNow(t, AssertIs(err, nil)) + throwFailNow(t, AssertIs(vForTesting.Value, value)) + + value2 := `s8s5da7as` + strPkForUpsert := &StrPk{ + Id: pk, + Value: value2, + } + + _, err = dORM.InsertOrUpdate(strPkForUpsert, `id`) if err != nil { fmt.Println(err) if err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport InsertOrUpdate in beego" { + return + } else if err == ErrLastInsertIdUnavailable { + return } else { throwFailNow(t, err) } } else { - dORM.Read(&test, "user_name") - throwFailNow(t, AssertIs(user2.Status, test.Status)) - throwFailNow(t, AssertIs(user2.Password, strings.TrimSpace(test.Password))) + var vForTesting2 StrPk + err = dORM.QueryTable(new(StrPk)).Filter(`id`, pk).One(&vForTesting2) + throwFailNow(t, AssertIs(err, nil)) + throwFailNow(t, AssertIs(vForTesting2.Value, value2)) } +} - //postgres ON CONFLICT DO UPDATE SET can`t use colu=colu+values - if IsPostgres { +func TestPSQueryBuilder(t *testing.T) { + // only test postgres + if dORM.Driver().Type() != 4 { return } - //test3 + - _, err = dORM.InsertOrUpdate(&user2, "user_name", "status=status+1") + + var user User + var l []userProfile + o := NewOrm() + + qb, err := NewQueryBuilder("postgres") if err != nil { - fmt.Println(err) - if err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport InsertOrUpdate in beego" { - } else { - throwFailNow(t, err) - } - } else { - dORM.Read(&test, "user_name") - throwFailNow(t, AssertIs(user2.Status+1, test.Status)) + throwFailNow(t, err) } - //test4 - - _, err = dORM.InsertOrUpdate(&user2, "user_name", "status=status-1") + qb.Select("user.id", "user.user_name"). + From("user").Where("id = ?").OrderBy("user_name"). + Desc().Limit(1).Offset(0) + sql := qb.String() + err = o.Raw(sql, 2).QueryRow(&user) if err != nil { - fmt.Println(err) - if err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport InsertOrUpdate in beego" { - } else { - throwFailNow(t, err) - } - } else { - dORM.Read(&test, "user_name") - throwFailNow(t, AssertIs((user2.Status+1)-1, test.Status)) + throwFailNow(t, err) } - //test5 * - _, err = dORM.InsertOrUpdate(&user2, "user_name", "status=status*3") + throwFail(t, AssertIs(user.UserName, "slene")) + + qb.Select("*"). + From("user_profile").InnerJoin("user"). + On("user_profile.id = user.id") + sql = qb.String() + num, err := o.Raw(sql).QueryRows(&l) if err != nil { - fmt.Println(err) - if err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport InsertOrUpdate in beego" { - } else { - throwFailNow(t, err) - } - } else { - dORM.Read(&test, "user_name") - throwFailNow(t, AssertIs(((user2.Status+1)-1)*3, test.Status)) + throwFailNow(t, err) } - //test6 / - _, err = dORM.InsertOrUpdate(&user2, "user_name", "Status=Status/3") - if err != nil { - fmt.Println(err) - if err.Error() == "postgres version must 9.5 or higher" || err.Error() == "`sqlite3` nonsupport InsertOrUpdate in beego" { - } else { - throwFailNow(t, err) + throwFailNow(t, AssertIs(num, 1)) + throwFailNow(t, AssertIs(l[0].UserName, "astaxie")) + throwFailNow(t, AssertIs(l[0].Age, 30)) +} + +func TestCondition(t *testing.T) { + // test Condition whether to include yourself + cond := NewCondition() + cond = cond.AndCond(cond.Or("ID", 1)) + cond = cond.AndCond(cond.Or("ID", 2)) + cond = cond.AndCond(cond.Or("ID", 3)) + cond = cond.AndCond(cond.Or("ID", 4)) + + cycleFlag := false + var hasCycle func(*Condition) + hasCycle = func(c *Condition) { + if nil == c || cycleFlag { + return + } + condPointMap := make(map[string]bool) + condPointMap[fmt.Sprintf("%p", c)] = true + for _, p := range c.params { + if p.isCond { + adr := fmt.Sprintf("%p", p.cond) + if condPointMap[adr] { + // self as sub cond was cycle + cycleFlag = true + break + } + condPointMap[adr] = true + + } + } + if cycleFlag { + return + } + for _, p := range c.params { + if p.isCond { + // check next cond + hasCycle(p.cond) + } } - } else { - dORM.Read(&test, "user_name") - throwFailNow(t, AssertIs((((user2.Status+1)-1)*3)/3, test.Status)) } + hasCycle(cond) + // cycleFlag was true,meaning use self as sub cond + throwFail(t, AssertIs(!cycleFlag, true)) +} + +func TestContextCanceled(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + + user := User{UserName: "slene"} + + err := dORM.ReadWithCtx(ctx, &user, "UserName") + throwFail(t, err) + + cancel() + err = dORM.ReadWithCtx(ctx, &user, "UserName") + throwFail(t, AssertIs(err, context.Canceled)) + + ctx, cancel = context.WithCancel(context.Background()) + cancel() + + qs := dORM.QueryTable(user) + _, err = qs.Filter("UserName", "slene").CountWithCtx(ctx) + throwFail(t, AssertIs(err, context.Canceled)) +} + +func TestDebugLog(t *testing.T) { + models.DefaultModelCache.Clean() + registerAllModel() + txCommitFn := func() { + o := NewOrm() + o.DoTx(func(ctx context.Context, txOrm TxOrmer) (txerr error) { + _, txerr = txOrm.QueryTable(&User{}).Count() + return + }) + } + + txRollbackFn := func() { + o := NewOrm() + o.DoTx(func(ctx context.Context, txOrm TxOrmer) (txerr error) { + user := NewUser() + user.UserName = "slene" + user.Email = "vslene@gmail.com" + user.Password = "pass" + user.Status = 3 + user.IsStaff = true + user.IsActive = true + + txOrm.Insert(user) + txerr = fmt.Errorf("mock error") + return + }) + } + + Debug = true + output1 := captureDebugLogOutput(txCommitFn) + + assert.Contains(t, output1, "START TRANSACTION") + assert.Contains(t, output1, "COMMIT") + + output2 := captureDebugLogOutput(txRollbackFn) + + assert.Contains(t, output2, "START TRANSACTION") + assert.Contains(t, output2, "ROLLBACK") + + Debug = false + output1 = captureDebugLogOutput(txCommitFn) + assert.EqualValues(t, output1, "") + + output2 = captureDebugLogOutput(txRollbackFn) + assert.EqualValues(t, output2, "") +} + +func captureDebugLogOutput(f func()) string { + var buf bytes.Buffer + DebugLog.SetOutput(&buf) + defer func() { + DebugLog.SetOutput(os.Stderr) + }() + f() + return buf.String() +} + +func TestReadRaw(t *testing.T) { + type TestModel struct { + Id int64 + Name string + Age int8 + } + ctx, cancel := context.WithCancel(context.Background()) + RegisterModel(new(TestModel)) + testModel := TestModel{Name: "user"} + SQL := "SELECT * FROM `test_model`;" + dORM = NewOrm() + err := dORM.ReadRaw(ctx, &testModel, SQL, nil) + cancel() + fmt.Print(err) +} + +func TestExecRaw(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + SQL := "INSERT INTO `null_value`(`id`,`value`) VALUES(?,?),(?,?);" + dORM = NewOrm() + if dORM.Driver().Type() == DRPostgres { + SQL = "INSERT INTO \"null_value\"(\"id\",\"value\") VALUES($1, $2),($3, $4);" + } + res, err := dORM.ExecRaw(ctx, &NullValue{}, + SQL, []interface{}{2, "Tom", 3, "Jerry"}...) + assert.Nil(t, err) + _, err = res.RowsAffected() + assert.Nil(t, err) + cancel() } diff --git a/orm/qb.go b/client/orm/qb.go similarity index 96% rename from orm/qb.go rename to client/orm/qb.go index e0655a178e..c82d2255de 100644 --- a/orm/qb.go +++ b/client/orm/qb.go @@ -52,7 +52,7 @@ func NewQueryBuilder(driver string) (qb QueryBuilder, err error) { } else if driver == "tidb" { qb = new(TiDBQueryBuilder) } else if driver == "postgres" { - err = errors.New("postgres query builder is not supported yet") + qb = new(PostgresQueryBuilder) } else if driver == "sqlite" { err = errors.New("sqlite query builder is not supported yet") } else { diff --git a/client/orm/qb/aggregate.go b/client/orm/qb/aggregate.go new file mode 100644 index 0000000000..803d658c10 --- /dev/null +++ b/client/orm/qb/aggregate.go @@ -0,0 +1,41 @@ +// Copyright 2023 beego. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package qb + +// Aggregate represents the aggregation function, such as AVG, MAX, MIN, etc +type Aggregate struct { + fn string + arg string + alias string +} + +func (a Aggregate) selectable() {} + +func (a Aggregate) expr() {} + +func (a Aggregate) As(alias string) Aggregate { + return Aggregate{ + fn: a.fn, + arg: a.arg, + alias: alias, + } +} + +func Avg(c string) Aggregate { + return Aggregate{ + fn: "AVG", + arg: c, + } +} diff --git a/client/orm/qb/builder.go b/client/orm/qb/builder.go new file mode 100644 index 0000000000..f7ea85729d --- /dev/null +++ b/client/orm/qb/builder.go @@ -0,0 +1,105 @@ +// Copyright 2023 beego. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package qb + +import ( + "github.com/beego/beego/v2/client/orm/internal/buffers" + "github.com/beego/beego/v2/client/orm/internal/models" + "github.com/beego/beego/v2/client/orm/qb/errs" +) + +type builder struct { + buffer buffers.Buffer + model *models.ModelInfo + args []any +} + +func (b *builder) space() { + b.writeByte(' ') +} + +func (b *builder) writeString(val string) { + _, _ = b.buffer.WriteString(val) +} + +func (b *builder) writeByte(c byte) { + _ = b.buffer.WriteByte(c) +} + +func (b *builder) end() { + b.writeByte(';') +} + +func (b *builder) comma() { + b.writeByte(',') +} + +func (b *builder) buildPredicates(predicates []Predicate) error { + p := predicates[0] + for i := 1; i < len(predicates); i++ { + p = p.And(predicates[i]) + } + return b.buildExpression(p) +} + +func (b *builder) buildExpression(e Expression) error { + if e == nil { + return nil + } + switch exp := e.(type) { + case Column: + b.writeByte('`') + fd, ok := b.model.Fields.Fields[exp.name] + if !ok { + return errs.NewErrUnknownField(exp.name) + } + b.writeString(fd.Column) + b.writeByte('`') + case value: + b.writeByte('?') + b.args = append(b.args, exp.val) + case Predicate: + _, lp := exp.left.(Predicate) + if lp { + b.writeByte('(') + } + if err := b.buildExpression(exp.left); err != nil { + return err + } + if lp { + b.writeByte(')') + } + b.space() + b.writeString(exp.op.String()) + b.space() + + _, rp := exp.right.(Predicate) + if rp { + b.writeByte('(') + } + if err := b.buildExpression(exp.right); err != nil { + return err + } + if rp { + b.writeByte(')') + } + case RawExpr: + b.writeString(exp.raw) + b.args = append(b.args, exp.args...) + default: + return errs.NewErrUnsupportedExpressionType(exp) + } + return nil +} diff --git a/client/orm/qb/column.go b/client/orm/qb/column.go new file mode 100644 index 0000000000..0cfea25a58 --- /dev/null +++ b/client/orm/qb/column.go @@ -0,0 +1,86 @@ +// Copyright 2023 beego. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package qb + +// Column contains the info of single column +type Column struct { + name string + alias string + order string +} + +func (c Column) As(alias string) Column { + return Column{ + name: c.name, + alias: alias, + order: c.order, + } +} +func (c Column) selectable() {} + +func (c Column) expr() {} +func (c Column) Asc() Column { + return Column{ + name: c.name, + alias: c.alias, + order: " ASC", + } +} +func (c Column) Desc() Column { + return Column{ + name: c.name, + alias: c.alias, + order: " DESC", + } +} + +type value struct { + val any +} + +func (c value) expr() {} + +func valueOf(val any) value { + return value{ + val: val, + } +} + +func C(name string) Column { + return Column{ + name: name, + } +} +func (c Column) EQ(arg any) Predicate { + return Predicate{ + left: c, + op: opEqual, + right: exprOf(arg), + } +} +func (c Column) GT(arg any) Predicate { + return Predicate{ + left: c, + op: opGT, + right: exprOf(arg), + } +} +func (c Column) LT(arg any) Predicate { + return Predicate{ + left: c, + op: opLT, + right: exprOf(arg), + } +} diff --git a/client/orm/qb/delete.go b/client/orm/qb/delete.go new file mode 100644 index 0000000000..32b3375156 --- /dev/null +++ b/client/orm/qb/delete.go @@ -0,0 +1,92 @@ +// Copyright 2023 beego. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package qb + +import ( + "context" + + "github.com/beego/beego/v2/client/orm" + "github.com/beego/beego/v2/client/orm/internal/buffers" + "github.com/beego/beego/v2/client/orm/internal/models" +) + +var _ QueryBuilder = &Deleter[any]{} + +// Deleter builds DELETE query +type Deleter[T any] struct { + builder + table interface{} + sess orm.QueryExecutor + where []Predicate +} + +// NewDeleter starts building a Delete query +func NewDeleter[T any](sess orm.QueryExecutor) *Deleter[T] { + return &Deleter[T]{ + sess: sess, + builder: builder{ + buffer: buffers.Get(), + }, + } +} + +func (d *Deleter[T]) Build() (*Query, error) { + defer buffers.Put(d.buffer) + d.writeString("DELETE FROM ") + var err error + if d.table == nil { + d.table = new(T) + } + registry := models.DefaultModelCache + d.model, err = registry.GetOrRegisterByMd(d.table) + if err != nil { + return nil, err + } + d.writeByte('`') + d.writeString(d.model.Table) + d.writeByte('`') + if len(d.where) > 0 { + d.writeString(" WHERE ") + err = d.buildPredicates(d.where) + if err != nil { + return nil, err + } + } + d.end() + return &Query{SQL: d.buffer.String(), Args: d.args}, nil +} + +// From accepts model definition +func (d *Deleter[T]) From(table interface{}) *Deleter[T] { + d.table = table + return d +} + +// Where accepts predicates +func (d *Deleter[T]) Where(predicates ...Predicate) *Deleter[T] { + d.where = predicates + return d +} + +// Exec sql +func (d *Deleter[T]) Exec(ctx context.Context) Result { + q, err := d.Build() + if err != nil { + return Result{err: err} + } + t := new(T) + res, err := d.sess.ExecRaw(ctx, t, q.SQL, q.Args...) + return Result{res: res, err: err} +} diff --git a/client/orm/qb/delete_test.go b/client/orm/qb/delete_test.go new file mode 100644 index 0000000000..c91eb5f8f3 --- /dev/null +++ b/client/orm/qb/delete_test.go @@ -0,0 +1,91 @@ +// Copyright 2023 beego. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package qb + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/beego/beego/v2/client/orm" +) + +func TestDeleter_Build(t *testing.T) { + err := orm.RegisterDataBase("default", "sqlite3", "") + if err != nil { + return + } + db := orm.NewOrm() + testCases := []struct { + name string + builder QueryBuilder + wantQuery *Query + wantErr error + }{ + { + name: "no where", + builder: NewDeleter[TestModel](db).From(&TestModel{}), + wantQuery: &Query{ + SQL: "DELETE FROM `test_model`;", + }, + }, + { + name: "where", + builder: NewDeleter[TestModel](db).Where(C("Id").EQ(16)), + wantQuery: &Query{ + SQL: "DELETE FROM `test_model` WHERE `id` = ?;", + Args: []interface{}{16}, + }, + }, + { + name: "no where combination", + builder: NewDeleter[TestCombinedModel](db).From(&TestCombinedModel{}), + wantQuery: &Query{ + SQL: "DELETE FROM `test_combined_model`;", + }, + }, + { + name: "where combination", + builder: NewDeleter[TestCombinedModel](db).Where(C("CreateTime").EQ(uint64(1000))), + wantQuery: &Query{ + SQL: "DELETE FROM `test_combined_model` WHERE `create_time` = ?;", + Args: []interface{}{uint64(1000)}, + }, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + query, err := tc.builder.Build() + assert.Equal(t, tc.wantErr, err) + if err != nil { + return + } + assert.Equal(t, tc.wantQuery, query) + }) + } +} + +type BaseEntity struct { + CreateTime uint64 + UpdateTime uint64 +} + +type TestCombinedModel struct { + BaseEntity + Id int64 `eorm:"primary_key"` + FirstName string + Age int8 + LastName *string +} diff --git a/client/orm/qb/errs/error.go b/client/orm/qb/errs/error.go new file mode 100644 index 0000000000..5a50241299 --- /dev/null +++ b/client/orm/qb/errs/error.go @@ -0,0 +1,33 @@ +// Copyright 2023 beego. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package errs + +import ( + "errors" + "fmt" +) + +var ErrGetByMd = errors.New("orm: Unknown error in get model") + +// NewErrUnknownField returns an error representing an unknown field +// Generally, it means that you may have entered a column name or an incorrect field name +func NewErrUnknownField(fd string) error { + return fmt.Errorf("orm: Unknown field %s", fd) +} + +// NewErrUnsupportedExpressionType returns an error message that does not support the expression +func NewErrUnsupportedExpressionType(exp any) error { + return fmt.Errorf("orm: Unsupported expression %v", exp) +} diff --git a/client/orm/qb/expression.go b/client/orm/qb/expression.go new file mode 100644 index 0000000000..47c311cfcd --- /dev/null +++ b/client/orm/qb/expression.go @@ -0,0 +1,40 @@ +// Copyright 2023 beego. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package qb + +// RawExpr Represents a native expression +// It means that ORM will not handle it in any way +type RawExpr struct { + raw string + args []interface{} +} + +func (r RawExpr) selectable() {} + +func (r RawExpr) expr() {} + +func (r RawExpr) AsPredicate() Predicate { + return Predicate{ + left: r, + } +} + +// Raw Create a RawExpr +func Raw(expr string, args ...interface{}) RawExpr { + return RawExpr{ + raw: expr, + args: args, + } +} diff --git a/client/orm/qb/predicate.go b/client/orm/qb/predicate.go new file mode 100644 index 0000000000..68db253550 --- /dev/null +++ b/client/orm/qb/predicate.go @@ -0,0 +1,73 @@ +// Copyright 2023 beego. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package qb + +type op string + +const ( + opEqual op = "=" + opLT op = "<" + opGT op = ">" + opAnd op = "AND" + opOr op = "OR" + opNot op = "NOT" +) + +func (o op) String() string { + return string(o) +} + +// Predicate Represents a query condition +type Predicate struct { + left Expression + op op + right Expression +} + +// Expression Represents a statement +type Expression interface { + expr() +} + +func (Predicate) expr() {} + +func exprOf(e any) Expression { + switch exp := e.(type) { + case Expression: + return exp + default: + return valueOf(exp) + } +} +func (p Predicate) And(r Predicate) Predicate { + return Predicate{ + left: p, + op: opAnd, + right: r, + } +} +func (p Predicate) Or(r Predicate) Predicate { + return Predicate{ + left: p, + op: opOr, + right: r, + } +} +func Not(r Predicate) Predicate { + return Predicate{ + op: opNot, + right: r, + } +} diff --git a/client/orm/qb/query.go b/client/orm/qb/query.go new file mode 100644 index 0000000000..f2b30d1d30 --- /dev/null +++ b/client/orm/qb/query.go @@ -0,0 +1,40 @@ +// Copyright 2023 beego. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package qb + +import ( + "context" + "database/sql" +) + +type Querier[T any] interface { + Get(ctx context.Context) (*T, error) + // batch query + GetMulti(ctx context.Context) (*T, error) +} + +// Executor is used for INSERT , DELETE , UPDATE +type Executor interface { + Exec(ctx context.Context) (sql.Result, error) +} + +type Query struct { + SQL string + Args []any +} + +type QueryBuilder interface { + Build() (*Query, error) +} diff --git a/client/orm/qb/result.go b/client/orm/qb/result.go new file mode 100644 index 0000000000..58251c9792 --- /dev/null +++ b/client/orm/qb/result.go @@ -0,0 +1,40 @@ +// Copyright 2023 beego. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package qb + +import "database/sql" + +type Result struct { + err error + res sql.Result +} + +func (r Result) Err() error { + return r.err +} + +func (r Result) LastInsertId() (int64, error) { + if r.err != nil { + return 0, r.err + } + return r.res.LastInsertId() +} + +func (r Result) RowsAffected() (int64, error) { + if r.err != nil { + return 0, r.err + } + return r.res.RowsAffected() +} diff --git a/client/orm/qb/result_test.go b/client/orm/qb/result_test.go new file mode 100644 index 0000000000..77fa9b8472 --- /dev/null +++ b/client/orm/qb/result_test.go @@ -0,0 +1,93 @@ +// Copyright 2023 beego. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package qb + +import ( + "errors" + "testing" + + "github.com/DATA-DOG/go-sqlmock" + "github.com/stretchr/testify/assert" +) + +func TestResult_RowsAffected(t *testing.T) { + testCases := []struct { + name string + res Result + wantAffected int64 + wantErr error + }{ + { + name: "err", + wantErr: errors.New("exec err"), + res: Result{err: errors.New("exec err")}, + }, + { + name: "unknown error", + wantErr: errors.New("unknown error"), + res: Result{res: sqlmock.NewErrorResult(errors.New("unknown error"))}, + }, + { + name: "no err", + wantAffected: int64(234), + res: Result{res: sqlmock.NewResult(123, 234)}, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + affected, err := tc.res.RowsAffected() + assert.Equal(t, tc.wantErr, err) + if err != nil { + return + } + assert.Equal(t, tc.wantAffected, affected) + }) + } +} + +func TestResult_LastInsertId(t *testing.T) { + testCases := []struct { + name string + res Result + wantLastId int64 + wantErr error + }{ + { + name: "err", + wantErr: errors.New("exec err"), + res: Result{err: errors.New("exec err")}, + }, + { + name: "res err", + wantErr: errors.New("exec err"), + res: Result{res: sqlmock.NewErrorResult(errors.New("exec err"))}, + }, + { + name: "no err", + wantLastId: int64(123), + res: Result{res: sqlmock.NewResult(123, 234)}, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + id, err := tc.res.LastInsertId() + assert.Equal(t, tc.wantErr, err) + if err != nil { + return + } + assert.Equal(t, tc.wantLastId, id) + }) + } +} diff --git a/client/orm/qb/select.go b/client/orm/qb/select.go new file mode 100644 index 0000000000..a99281fe3b --- /dev/null +++ b/client/orm/qb/select.go @@ -0,0 +1,260 @@ +// Copyright 2023 beego. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package qb + +import ( + "context" + "errors" + "reflect" + + "github.com/beego/beego/v2/client/orm/internal/buffers" + + "github.com/beego/beego/v2/client/orm" + + "github.com/beego/beego/v2/client/orm/internal/models" + "github.com/beego/beego/v2/client/orm/qb/errs" +) + +var _ QueryBuilder = &Selector[any]{} + +// The Selector is used to construct a SELECT statement +type Selector[T any] struct { + builder + orderBy []Column + where []Predicate + offset int + limit int + columns []Selectable + tableName string + sess orm.QueryExecutor + // allow users to specify the registry + registry *models.ModelCache +} + +func NewSelector[T any](sess orm.QueryExecutor) *Selector[T] { + return &Selector[T]{ + sess: sess, + builder: builder{ + buffer: buffers.Get(), + }, + registry: models.DefaultModelCache, + } +} + +func (s *Selector[T]) Build() (*Query, error) { + var ( + t T + err error + ) + defer buffers.Put(s.buffer) + s.model, err = s.registry.GetOrRegisterByMd(&t) + if err != nil { + return nil, err + } + s.writeString("SELECT ") + if err = s.buildColumns(); err != nil { + return nil, err + } + s.writeString(" FROM ") + s.buildTable() + if len(s.where) > 0 { + s.writeString(" WHERE ") + err = s.buildPredicates(s.where) + if err != nil { + return nil, err + } + } + if len(s.orderBy) > 0 { + s.writeString(" ORDER BY ") + for i, c := range s.orderBy { + if i > 0 { + s.comma() + } + if err = s.buildColumn(c, false); err != nil { + return nil, err + } + } + } + if s.limit > 0 { + s.writeString(" LIMIT ?") + s.addArgs(s.limit) + } + if s.offset > 0 { + s.writeString(" OFFSET ?") + s.addArgs(s.offset) + } + + s.end() + return &Query{ + SQL: s.buffer.String(), + Args: s.args, + }, nil +} + +func (s *Selector[T]) buildTable() { + if s.tableName != "" { + s.writeByte('`') + s.writeString(s.tableName) + s.writeByte('`') + } else { + if s.model.Table == "" { + var t T + typ := reflect.TypeOf(t) + s.writeByte('`') + s.writeString(typ.Name()) + s.writeByte('`') + } else { + s.writeByte('`') + s.writeString(s.model.Table) + s.writeByte('`') + } + } +} + +func (s *Selector[T]) addArgs(args ...any) { + if s.args == nil { + s.args = make([]any, 0, 8) + } + s.args = append(s.args, args...) +} + +func (s *Selector[T]) buildColumn(c Column, useAlias bool) error { + s.writeByte('`') + fd, ok := s.model.Fields.Fields[c.name] + if !ok { + return errs.NewErrUnknownField(c.name) + } + s.writeString(fd.Column) + s.writeByte('`') + if c.order != "" { + s.writeString(c.order) + } + if useAlias { + s.buildAs(c.alias) + } + return nil +} + +func (s *Selector[T]) buildColumns() error { + if len(s.columns) == 0 { + s.writeByte('*') + return nil + } + for i, c := range s.columns { + if i > 0 { + s.comma() + } + switch val := c.(type) { + case Column: + if err := s.buildColumn(val, true); err != nil { + return err + } + case Aggregate: + if err := s.buildAggregate(val, true); err != nil { + return err + } + case RawExpr: + s.writeString(val.raw) + if len(val.args) != 0 { + s.addArgs(val.args...) + } + default: + return errors.New("orm: Unsupported target column") + } + } + return nil +} + +func (s *Selector[T]) buildAggregate(a Aggregate, useAlias bool) error { + s.writeString(a.fn) + s.writeString("(`") + fd, ok := s.model.Fields.Fields[a.arg] + if !ok { + return errs.NewErrUnknownField(a.arg) + } + s.writeString(fd.Column) + s.writeString("`)") + if useAlias { + s.buildAs(a.alias) + } + return nil +} + +func (s *Selector[T]) From(table string) *Selector[T] { + s.tableName = table + return s +} + +func (s *Selector[T]) Where(ps ...Predicate) *Selector[T] { + s.where = append(s.where, ps...) + return s +} + +func (s *Selector[T]) OrderBy(cols ...Column) *Selector[T] { + s.orderBy = cols + return s +} + +func (s *Selector[T]) Offset(offset int) *Selector[T] { + s.offset = offset + return s +} + +func (s *Selector[T]) Limit(limit int) *Selector[T] { + s.limit = limit + return s +} + +func (s *Selector[T]) Select(cols ...Selectable) *Selector[T] { + s.columns = cols + return s +} + +type Selectable interface { + selectable() +} + +func (s *Selector[T]) buildAs(alias string) { + if alias != "" { + s.writeString(" AS ") + s.writeByte('`') + s.writeString(alias) + s.writeByte('`') + } +} + +func (s *Selector[T]) Get(ctx context.Context) (*T, error) { + q, err := s.Build() + if err != nil { + return nil, err + } + t := new(T) + err = s.sess.ReadRaw(ctx, t, q.SQL, q.Args...) + return t, nil +} + +// The WhereRaw change the case like where do. +func (s *Selector[T]) WhereMap(conds map[string]any) *Selector[T] { + for col, val := range conds { + ps := C(col).EQ(val) + s.Where(ps) + } + return s +} + +// The WhereRaw does not change the case like where do. +func (s *Selector[T]) WhereRaw(raw string, args ...any) *Selector[T] { + return s.Where(Raw(raw, args...).AsPredicate()) +} diff --git a/client/orm/qb/select_test.go b/client/orm/qb/select_test.go new file mode 100644 index 0000000000..96cb9df6f9 --- /dev/null +++ b/client/orm/qb/select_test.go @@ -0,0 +1,451 @@ +// Copyright 2023 beego. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package qb + +import ( + "database/sql" + "golang.org/x/net/context" + "testing" + + _ "github.com/mattn/go-sqlite3" + "github.com/stretchr/testify/assert" + + "github.com/beego/beego/v2/client/orm" + "github.com/beego/beego/v2/client/orm/qb/errs" +) + +func TestTx_Commit(t *testing.T) { + err := orm.RegisterDataBase("default", "sqlite3", "") + if err != nil { + return + } + db := orm.NewOrm() + + tx, err := db.BeginWithCtxAndOpts(context.Background(), &sql.TxOptions{}) + assert.Nil(t, err) + + q := NewSelector[TestModel](tx).From("test_model").Where(C("Id").EQ(1)) + query, er := q.Build() + assert.Nil(t, er) + assert.Equal(t, "SELECT * FROM `test_model` WHERE `id` = ?;", query.SQL) + + err = tx.Commit() + assert.Nil(t, err) +} + +func TestTx_Rollback(t *testing.T) { + err := orm.RegisterDataBase("default", "sqlite3", "") + if err != nil { + return + } + db := orm.NewOrm() + + tx, err := db.BeginWithCtxAndOpts(context.Background(), &sql.TxOptions{}) + assert.Nil(t, err) + + q := NewSelector[TestModel](tx).From("test_model").Where(C("Id").EQ(1)) + query, er := q.Build() + assert.Nil(t, er) + assert.Equal(t, "SELECT * FROM `test_model` WHERE `id` = ?;", query.SQL) + + err = tx.Rollback() + assert.Nil(t, err) +} + +func TestExampleTransactionSelector(t *testing.T) { + err := orm.RegisterDataBase("default", "sqlite3", "") + if err != nil { + return + } + db := orm.NewOrm() + err = db.DoTx(func(ctx context.Context, txOrm orm.TxOrmer) error { + q := NewSelector[TestModel](txOrm).From("test_model").Where(C("Id").EQ(1)) + query, er := q.Build() + if er != nil { + return er + } + assert.Equal(t, "SELECT * FROM `test_model` WHERE `id` = ?;", query.SQL) + return nil + }) + assert.Nil(t, err) +} + +func TestSelector_RawAndWhereMap(t *testing.T) { + err := orm.RegisterDataBase("default", "sqlite3", "") + if err != nil { + return + } + db := orm.NewOrm() + testCase := []struct { + name string + q QueryBuilder + wantQuery *Query + wantErr error + }{ + { + name: "WhereRaw", + q: NewSelector[TestModel](db).WhereRaw("`age` = ? and `first_name` = ?", 18, "sep"), + wantQuery: &Query{ + // There are two spaces at the end because we use predicate but not predicate.op + SQL: "SELECT * FROM `test_model` WHERE `age` = ? and `first_name` = ? ;", + Args: []any{18, "sep"}, + }, + }, + // The WhereMap test might fail because the traversal of the map is unordered. + { + name: "WhereMap", + q: NewSelector[TestModel](db).WhereMap(map[string]any{"Age": 18}), + wantQuery: &Query{ + SQL: "SELECT * FROM `test_model` WHERE `age` = ?;", + Args: []any{18}, + }, + }, + { + name: "WhereMapAndWhereRaw", + q: NewSelector[TestModel](db).WhereMap(map[string]any{"Age": 18}).WhereRaw("`id` = ? and `last_name` = ?", 1, "join"), + wantQuery: &Query{ + SQL: "SELECT * FROM `test_model` WHERE (`age` = ?) AND (`id` = ? and `last_name` = ? );", + Args: []any{18, 1, "join"}, + }, + }, + + { + name: "Where_WhereMap_WhereRaw", + q: NewSelector[TestModel](db).Where(C("LastName").EQ("join")).WhereMap(map[string]any{"Age": 18}).WhereRaw("`id` = ?", 1), + wantQuery: &Query{ + // There are two spaces before NOT because we did not perform any special processing on NOT + SQL: "SELECT * FROM `test_model` WHERE ((`last_name` = ?) AND (`age` = ?)) AND (`id` = ? );", + Args: []any{"join", 18, 1}, + }, + }, + } + + for _, tc := range testCase { + t.Run(tc.name, func(t *testing.T) { + q, err := tc.q.Build() + assert.Equal(t, tc.wantErr, err) + if err != nil { + return + } + assert.Equal(t, tc.wantQuery, q) + }) + } +} + +func TestSelector_Build(t *testing.T) { + err := orm.RegisterDataBase("default", "sqlite3", "") + if err != nil { + return + } + db := orm.NewOrm() + testCase := []struct { + name string + q QueryBuilder + wantQuery *Query + wantErr error + }{ + { + name: "no from", + q: NewSelector[TestModel](db), + wantQuery: &Query{ + SQL: "SELECT * FROM `test_model`;", + Args: nil, + }, + }, { + name: "from", + q: NewSelector[TestModel](db).From("from_test"), + wantQuery: &Query{ + SQL: "SELECT * FROM `from_test`;", + Args: nil, + }, + }, + { + name: "no from", + q: NewSelector[TestModel](db).From(""), + wantQuery: &Query{ + SQL: "SELECT * FROM `test_model`;", + Args: nil, + }, + }, { + name: "test_db", + q: NewSelector[TestModel](db).From("`test_db`.`db_model`"), + wantQuery: &Query{ + SQL: "SELECT * FROM `test_db`.`db_model`;", + Args: nil, + }, + }, { + name: "single and simple predicate", + q: NewSelector[TestModel](db).From("`test_model_t`"). + Where(C("Id").EQ(1)), + wantQuery: &Query{ + SQL: "SELECT * FROM `test_model_t` WHERE `Id` = ?;", + Args: []any{1}, + }, + }, + { + name: "multiple predicates", + q: NewSelector[TestModel](db). + Where(C("Age").GT(18), C("Age").LT(35)), + wantQuery: &Query{ + SQL: "SELECT * FROM `test_model` WHERE (`Age` > ?) AND (`Age` < ?);", + Args: []any{18, 35}, + }, + }, + { + name: "and", + q: NewSelector[TestModel](db). + Where(C("Age").GT(18).And(C("Age").LT(35))), + wantQuery: &Query{ + SQL: "SELECT * FROM `test_model` WHERE (`Age` > ?) AND (`Age` < ?);", + Args: []any{18, 35}, + }, + }, + { + name: "or", + q: NewSelector[TestModel](db). + Where(C("Age").GT(18).Or(C("Age").LT(35))), + wantQuery: &Query{ + SQL: "SELECT * FROM `test_model` WHERE (`Age` > ?) OR (`Age` < ?);", + Args: []any{18, 35}, + }, + }, + { + name: "not", + q: NewSelector[TestModel](db).Where(Not(C("Age").GT(18))), + wantQuery: &Query{ + // There are two spaces before NOT because we did not perform any special processing on NOT + SQL: "SELECT * FROM `test_model` WHERE NOT (`Age` > ?);", + Args: []any{18}, + }, + }, + } + + for _, tc := range testCase { + t.Run(tc.name, func(t *testing.T) { + q, err := tc.q.Build() + assert.Equal(t, tc.wantErr, err) + if err != nil { + return + } + assert.Equal(t, tc.wantQuery, q) + }) + } +} + +func TestSelector_OffsetLimit(t *testing.T) { + err := orm.RegisterDataBase("default", "sqlite3", "") + if err != nil { + return + } + db := orm.NewOrm() + testCases := []struct { + name string + q QueryBuilder + wantQuery *Query + wantErr error + }{ + { + name: "offset only", + q: NewSelector[TestModel](db).Offset(10), + wantQuery: &Query{ + SQL: "SELECT * FROM `test_model` OFFSET ?;", + Args: []any{10}, + }, + }, + { + name: "limit only", + q: NewSelector[TestModel](db).Limit(10), + wantQuery: &Query{ + SQL: "SELECT * FROM `test_model` LIMIT ?;", + Args: []any{10}, + }, + }, + { + name: "limit offset", + q: NewSelector[TestModel](db).Limit(20).Offset(10), + wantQuery: &Query{ + SQL: "SELECT * FROM `test_model` LIMIT ? OFFSET ?;", + Args: []any{20, 10}, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + query, err := tc.q.Build() + assert.Equal(t, tc.wantErr, err) + if err != nil { + return + } + assert.Equal(t, tc.wantQuery, query) + }) + } +} + +func TestSelector_OrderBy(t *testing.T) { + err := orm.RegisterDataBase("default", "sqlite3", "") + if err != nil { + return + } + db := orm.NewOrm() + testCases := []struct { + name string + q QueryBuilder + wantQuery *Query + wantErr error + }{ + { + name: "none", + q: NewSelector[TestModel](db).OrderBy(), + wantQuery: &Query{ + SQL: "SELECT * FROM `test_model`;", + }, + }, + { + name: "single", + q: NewSelector[TestModel](db).OrderBy(C("Age")), + wantQuery: &Query{ + SQL: "SELECT * FROM `test_model` ORDER BY `age`;", + }, + }, + { + name: "single asc", + q: NewSelector[TestModel](db).OrderBy(C("Age").Asc()), + wantQuery: &Query{ + SQL: "SELECT * FROM `test_model` ORDER BY `age` ASC;", + }, + }, + { + name: "single desc", + q: NewSelector[TestModel](db).OrderBy(C("Age").Desc()), + wantQuery: &Query{ + SQL: "SELECT * FROM `test_model` ORDER BY `age` DESC;", + }, + }, + { + name: "multiple", + q: NewSelector[TestModel](db).OrderBy(C("Age").Asc(), C("FirstName").Desc()), + wantQuery: &Query{ + SQL: "SELECT * FROM `test_model` ORDER BY `age` ASC,`first_name` DESC;", + }, + }, + { + name: "multiple asc", + q: NewSelector[TestModel](db).OrderBy(C("Age"), C("FirstName")), + wantQuery: &Query{ + SQL: "SELECT * FROM `test_model` ORDER BY `age`,`first_name`;", + }, + }, + { + name: "invalid column", + q: NewSelector[TestModel](db).OrderBy(C("Invalid")), + wantErr: errs.NewErrUnknownField("Invalid"), + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + query, err := tc.q.Build() + assert.Equal(t, tc.wantErr, err) + if err != nil { + return + } + assert.Equal(t, tc.wantQuery, query) + }) + } +} + +func TestSelector_Select(t *testing.T) { + err := orm.RegisterDataBase("default", "sqlite3", "") + if err != nil { + return + } + db := orm.NewOrm() + testCases := []struct { + name string + q QueryBuilder + wantQuery *Query + wantErr error + }{ + { + name: "all", + q: NewSelector[TestModel](db), + wantQuery: &Query{ + SQL: "SELECT * FROM `test_model`;", + }, + }, + { + name: "invalid column", + q: NewSelector[TestModel](db).Select(Avg("Invalid")), + wantErr: errs.NewErrUnknownField("Invalid"), + }, + { + name: "partial columns", + q: NewSelector[TestModel](db).Select(C("Id"), C("FirstName")), + wantQuery: &Query{ + SQL: "SELECT `id`,`first_name` FROM `test_model`;", + }, + }, + { + name: "avg", + q: NewSelector[TestModel](db).Select(Avg("Age")), + wantQuery: &Query{ + SQL: "SELECT AVG(`age`) FROM `test_model`;", + }, + }, + { + name: "raw expression", + q: NewSelector[TestModel](db).Select(Raw("COUNT(DISTINCT `first_name`)")), + wantQuery: &Query{ + SQL: "SELECT COUNT(DISTINCT `first_name`) FROM `test_model`;", + }, + }, + { + name: "alias", + q: NewSelector[TestModel](db). + Select(C("Id").As("my_id"), + Avg("Age").As("avg_age")), + wantQuery: &Query{ + SQL: "SELECT `id` AS `my_id`,AVG(`age`) AS `avg_age` FROM `test_model`;", + }, + }, + { + name: "where ignore alias", + q: NewSelector[TestModel](db). + Where(C("Id").As("my_id").LT(100)), + wantQuery: &Query{ + SQL: "SELECT * FROM `test_model` WHERE `id` < ?;", + Args: []any{100}, + }, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + query, err := tc.q.Build() + assert.Equal(t, tc.wantErr, err) + if err != nil { + return + } + assert.Equal(t, tc.wantQuery, query) + }) + } +} + +type TestModel struct { + Id int64 + // "" + FirstName string + Age int8 + LastName sql.NullString +} diff --git a/orm/qb_mysql.go b/client/orm/qb_mysql.go similarity index 69% rename from orm/qb_mysql.go rename to client/orm/qb_mysql.go index 23bdc9eef9..486299a9d6 100644 --- a/orm/qb_mysql.go +++ b/client/orm/qb_mysql.go @@ -25,144 +25,144 @@ const CommaSpace = ", " // MySQLQueryBuilder is the SQL build type MySQLQueryBuilder struct { - Tokens []string + tokens []string } -// Select will join the fields +// Select will join the Fields func (qb *MySQLQueryBuilder) Select(fields ...string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "SELECT", strings.Join(fields, CommaSpace)) + qb.tokens = append(qb.tokens, "SELECT", strings.Join(fields, CommaSpace)) return qb } // ForUpdate add the FOR UPDATE clause func (qb *MySQLQueryBuilder) ForUpdate() QueryBuilder { - qb.Tokens = append(qb.Tokens, "FOR UPDATE") + qb.tokens = append(qb.tokens, "FOR UPDATE") return qb } // From join the tables func (qb *MySQLQueryBuilder) From(tables ...string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "FROM", strings.Join(tables, CommaSpace)) + qb.tokens = append(qb.tokens, "FROM", strings.Join(tables, CommaSpace)) return qb } // InnerJoin INNER JOIN the table func (qb *MySQLQueryBuilder) InnerJoin(table string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "INNER JOIN", table) + qb.tokens = append(qb.tokens, "INNER JOIN", table) return qb } // LeftJoin LEFT JOIN the table func (qb *MySQLQueryBuilder) LeftJoin(table string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "LEFT JOIN", table) + qb.tokens = append(qb.tokens, "LEFT JOIN", table) return qb } // RightJoin RIGHT JOIN the table func (qb *MySQLQueryBuilder) RightJoin(table string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "RIGHT JOIN", table) + qb.tokens = append(qb.tokens, "RIGHT JOIN", table) return qb } // On join with on cond func (qb *MySQLQueryBuilder) On(cond string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "ON", cond) + qb.tokens = append(qb.tokens, "ON", cond) return qb } // Where join the Where cond func (qb *MySQLQueryBuilder) Where(cond string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "WHERE", cond) + qb.tokens = append(qb.tokens, "WHERE", cond) return qb } // And join the and cond func (qb *MySQLQueryBuilder) And(cond string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "AND", cond) + qb.tokens = append(qb.tokens, "AND", cond) return qb } // Or join the or cond func (qb *MySQLQueryBuilder) Or(cond string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "OR", cond) + qb.tokens = append(qb.tokens, "OR", cond) return qb } // In join the IN (vals) func (qb *MySQLQueryBuilder) In(vals ...string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "IN", "(", strings.Join(vals, CommaSpace), ")") + qb.tokens = append(qb.tokens, "IN", "(", strings.Join(vals, CommaSpace), ")") return qb } -// OrderBy join the Order by fields +// OrderBy join the Order by Fields func (qb *MySQLQueryBuilder) OrderBy(fields ...string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "ORDER BY", strings.Join(fields, CommaSpace)) + qb.tokens = append(qb.tokens, "ORDER BY", strings.Join(fields, CommaSpace)) return qb } // Asc join the asc func (qb *MySQLQueryBuilder) Asc() QueryBuilder { - qb.Tokens = append(qb.Tokens, "ASC") + qb.tokens = append(qb.tokens, "ASC") return qb } // Desc join the desc func (qb *MySQLQueryBuilder) Desc() QueryBuilder { - qb.Tokens = append(qb.Tokens, "DESC") + qb.tokens = append(qb.tokens, "DESC") return qb } // Limit join the limit num func (qb *MySQLQueryBuilder) Limit(limit int) QueryBuilder { - qb.Tokens = append(qb.Tokens, "LIMIT", strconv.Itoa(limit)) + qb.tokens = append(qb.tokens, "LIMIT", strconv.Itoa(limit)) return qb } // Offset join the offset num func (qb *MySQLQueryBuilder) Offset(offset int) QueryBuilder { - qb.Tokens = append(qb.Tokens, "OFFSET", strconv.Itoa(offset)) + qb.tokens = append(qb.tokens, "OFFSET", strconv.Itoa(offset)) return qb } -// GroupBy join the Group by fields +// GroupBy join the Group by Fields func (qb *MySQLQueryBuilder) GroupBy(fields ...string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "GROUP BY", strings.Join(fields, CommaSpace)) + qb.tokens = append(qb.tokens, "GROUP BY", strings.Join(fields, CommaSpace)) return qb } // Having join the Having cond func (qb *MySQLQueryBuilder) Having(cond string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "HAVING", cond) + qb.tokens = append(qb.tokens, "HAVING", cond) return qb } // Update join the update table func (qb *MySQLQueryBuilder) Update(tables ...string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "UPDATE", strings.Join(tables, CommaSpace)) + qb.tokens = append(qb.tokens, "UPDATE", strings.Join(tables, CommaSpace)) return qb } -// Set join the set kv +// Set join the Set kv func (qb *MySQLQueryBuilder) Set(kv ...string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "SET", strings.Join(kv, CommaSpace)) + qb.tokens = append(qb.tokens, "SET", strings.Join(kv, CommaSpace)) return qb } // Delete join the Delete tables func (qb *MySQLQueryBuilder) Delete(tables ...string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "DELETE") + qb.tokens = append(qb.tokens, "DELETE") if len(tables) != 0 { - qb.Tokens = append(qb.Tokens, strings.Join(tables, CommaSpace)) + qb.tokens = append(qb.tokens, strings.Join(tables, CommaSpace)) } return qb } // InsertInto join the insert SQL func (qb *MySQLQueryBuilder) InsertInto(table string, fields ...string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "INSERT INTO", table) + qb.tokens = append(qb.tokens, "INSERT INTO", table) if len(fields) != 0 { fieldsStr := strings.Join(fields, CommaSpace) - qb.Tokens = append(qb.Tokens, "(", fieldsStr, ")") + qb.tokens = append(qb.tokens, "(", fieldsStr, ")") } return qb } @@ -170,7 +170,7 @@ func (qb *MySQLQueryBuilder) InsertInto(table string, fields ...string) QueryBui // Values join the Values(vals) func (qb *MySQLQueryBuilder) Values(vals ...string) QueryBuilder { valsStr := strings.Join(vals, CommaSpace) - qb.Tokens = append(qb.Tokens, "VALUES", "(", valsStr, ")") + qb.tokens = append(qb.tokens, "VALUES", "(", valsStr, ")") return qb } @@ -179,7 +179,9 @@ func (qb *MySQLQueryBuilder) Subquery(sub string, alias string) string { return fmt.Sprintf("(%s) AS %s", sub, alias) } -// String join all Tokens +// String join All tokens func (qb *MySQLQueryBuilder) String() string { - return strings.Join(qb.Tokens, " ") + s := strings.Join(qb.tokens, " ") + qb.tokens = qb.tokens[:0] + return s } diff --git a/client/orm/qb_postgres.go b/client/orm/qb_postgres.go new file mode 100644 index 0000000000..713fb01498 --- /dev/null +++ b/client/orm/qb_postgres.go @@ -0,0 +1,219 @@ +package orm + +import ( + "fmt" + "strconv" + "strings" +) + +var quote string = `"` + +// PostgresQueryBuilder is the SQL build +type PostgresQueryBuilder struct { + tokens []string +} + +func processingStr(str []string) string { + s := strings.Join(str, `","`) + s = fmt.Sprintf("%s%s%s", quote, s, quote) + return s +} + +// Select will join the Fields +func (qb *PostgresQueryBuilder) Select(fields ...string) QueryBuilder { + var str string + n := len(fields) + + if fields[0] == "*" { + str = "*" + } else { + for i := 0; i < n; i++ { + sli := strings.Split(fields[i], ".") + s := strings.Join(sli, `"."`) + s = fmt.Sprintf("%s%s%s", quote, s, quote) + if n == 1 || i == n-1 { + str += s + } else { + str += s + "," + } + } + } + + qb.tokens = append(qb.tokens, "SELECT", str) + return qb +} + +// ForUpdate add the FOR UPDATE clause +func (qb *PostgresQueryBuilder) ForUpdate() QueryBuilder { + qb.tokens = append(qb.tokens, "FOR UPDATE") + return qb +} + +// From join the tables +func (qb *PostgresQueryBuilder) From(tables ...string) QueryBuilder { + str := processingStr(tables) + qb.tokens = append(qb.tokens, "FROM", str) + return qb +} + +// InnerJoin INNER JOIN the table +func (qb *PostgresQueryBuilder) InnerJoin(table string) QueryBuilder { + str := fmt.Sprintf("%s%s%s", quote, table, quote) + qb.tokens = append(qb.tokens, "INNER JOIN", str) + return qb +} + +// LeftJoin LEFT JOIN the table +func (qb *PostgresQueryBuilder) LeftJoin(table string) QueryBuilder { + str := fmt.Sprintf("%s%s%s", quote, table, quote) + qb.tokens = append(qb.tokens, "LEFT JOIN", str) + return qb +} + +// RightJoin RIGHT JOIN the table +func (qb *PostgresQueryBuilder) RightJoin(table string) QueryBuilder { + str := fmt.Sprintf("%s%s%s", quote, table, quote) + qb.tokens = append(qb.tokens, "RIGHT JOIN", str) + return qb +} + +// On join with on cond +func (qb *PostgresQueryBuilder) On(cond string) QueryBuilder { + var str string + cond = strings.Replace(cond, " ", "", -1) + slice := strings.Split(cond, "=") + for i := 0; i < len(slice); i++ { + sli := strings.Split(slice[i], ".") + s := strings.Join(sli, `"."`) + s = fmt.Sprintf("%s%s%s", quote, s, quote) + if i == 0 { + str = s + " =" + " " + } else { + str += s + } + } + + qb.tokens = append(qb.tokens, "ON", str) + return qb +} + +// Where join the Where cond +func (qb *PostgresQueryBuilder) Where(cond string) QueryBuilder { + qb.tokens = append(qb.tokens, "WHERE", cond) + return qb +} + +// And join the and cond +func (qb *PostgresQueryBuilder) And(cond string) QueryBuilder { + qb.tokens = append(qb.tokens, "AND", cond) + return qb +} + +// Or join the or cond +func (qb *PostgresQueryBuilder) Or(cond string) QueryBuilder { + qb.tokens = append(qb.tokens, "OR", cond) + return qb +} + +// In join the IN (vals) +func (qb *PostgresQueryBuilder) In(vals ...string) QueryBuilder { + qb.tokens = append(qb.tokens, "IN", "(", strings.Join(vals, CommaSpace), ")") + return qb +} + +// OrderBy join the Order by Fields +func (qb *PostgresQueryBuilder) OrderBy(fields ...string) QueryBuilder { + str := processingStr(fields) + qb.tokens = append(qb.tokens, "ORDER BY", str) + return qb +} + +// Asc join the asc +func (qb *PostgresQueryBuilder) Asc() QueryBuilder { + qb.tokens = append(qb.tokens, "ASC") + return qb +} + +// Desc join the desc +func (qb *PostgresQueryBuilder) Desc() QueryBuilder { + qb.tokens = append(qb.tokens, "DESC") + return qb +} + +// Limit join the limit num +func (qb *PostgresQueryBuilder) Limit(limit int) QueryBuilder { + qb.tokens = append(qb.tokens, "LIMIT", strconv.Itoa(limit)) + return qb +} + +// Offset join the offset num +func (qb *PostgresQueryBuilder) Offset(offset int) QueryBuilder { + qb.tokens = append(qb.tokens, "OFFSET", strconv.Itoa(offset)) + return qb +} + +// GroupBy join the Group by Fields +func (qb *PostgresQueryBuilder) GroupBy(fields ...string) QueryBuilder { + str := processingStr(fields) + qb.tokens = append(qb.tokens, "GROUP BY", str) + return qb +} + +// Having join the Having cond +func (qb *PostgresQueryBuilder) Having(cond string) QueryBuilder { + qb.tokens = append(qb.tokens, "HAVING", cond) + return qb +} + +// Update join the update table +func (qb *PostgresQueryBuilder) Update(tables ...string) QueryBuilder { + str := processingStr(tables) + qb.tokens = append(qb.tokens, "UPDATE", str) + return qb +} + +// Set join the Set kv +func (qb *PostgresQueryBuilder) Set(kv ...string) QueryBuilder { + qb.tokens = append(qb.tokens, "SET", strings.Join(kv, CommaSpace)) + return qb +} + +// Delete join the Delete tables +func (qb *PostgresQueryBuilder) Delete(tables ...string) QueryBuilder { + qb.tokens = append(qb.tokens, "DELETE") + if len(tables) != 0 { + str := processingStr(tables) + qb.tokens = append(qb.tokens, str) + } + return qb +} + +// InsertInto join the insert SQL +func (qb *PostgresQueryBuilder) InsertInto(table string, fields ...string) QueryBuilder { + str := fmt.Sprintf("%s%s%s", quote, table, quote) + qb.tokens = append(qb.tokens, "INSERT INTO", str) + if len(fields) != 0 { + fieldsStr := strings.Join(fields, CommaSpace) + qb.tokens = append(qb.tokens, "(", fieldsStr, ")") + } + return qb +} + +// Values join the Values(vals) +func (qb *PostgresQueryBuilder) Values(vals ...string) QueryBuilder { + valsStr := strings.Join(vals, CommaSpace) + qb.tokens = append(qb.tokens, "VALUES", "(", valsStr, ")") + return qb +} + +// Subquery join the sub as alias +func (qb *PostgresQueryBuilder) Subquery(sub string, alias string) string { + return fmt.Sprintf("(%s) AS %s", sub, alias) +} + +// String join All tokens +func (qb *PostgresQueryBuilder) String() string { + s := strings.Join(qb.tokens, " ") + qb.tokens = qb.tokens[:0] + return s +} diff --git a/client/orm/qb_tidb.go b/client/orm/qb_tidb.go new file mode 100644 index 0000000000..772edb5d50 --- /dev/null +++ b/client/orm/qb_tidb.go @@ -0,0 +1,21 @@ +// Copyright 2015 TiDB Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +// TiDBQueryBuilder is the SQL build +type TiDBQueryBuilder struct { + MySQLQueryBuilder + tokens []string +} diff --git a/client/orm/types.go b/client/orm/types.go new file mode 100644 index 0000000000..984225cd78 --- /dev/null +++ b/client/orm/types.go @@ -0,0 +1,669 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "context" + "database/sql" + "reflect" + "time" + + "github.com/beego/beego/v2/client/orm/internal/models" + + "github.com/beego/beego/v2/client/orm/clauses/order_clause" + "github.com/beego/beego/v2/core/utils" +) + +// TableNameI is usually used by model +// when you custom your table name, please implement this interfaces +// for example: +// +// type User struct { +// ... +// } +// +// func (u *User) TableName() string { +// return "USER_TABLE" +// } +type TableNameI interface { + TableName() string +} + +// TableEngineI is usually used by model +// when you want to use specific engine, like myisam, you can implement this interface +// for example: +// +// type User struct { +// ... +// } +// +// func (u *User) TableEngine() string { +// return "myisam" +// } +type TableEngineI interface { + TableEngine() string +} + +// TableIndexI is usually used by model +// when you want to create indexes, you can implement this interface +// for example: +// +// type User struct { +// ... +// } +// +// func (u *User) TableIndex() [][]string { +// return [][]string{{"Name"}} +// } +type TableIndexI interface { + TableIndex() [][]string +} + +// TableUniqueI is usually used by model +// when you want to create unique indexes, you can implement this interface +// for example: +// +// type User struct { +// ... +// } +// +// func (u *User) TableUnique() [][]string { +// return [][]string{{"Email"}} +// } +type TableUniqueI interface { + TableUnique() [][]string +} + +// IsApplicableTableForDB if return false, we won't create table to this db +type IsApplicableTableForDB interface { + IsApplicableTableForDB(db string) bool +} + +// Driver define database driver +type Driver interface { + Name() string + Type() DriverType +} + +type Fielder = models.Fielder + +type TxBeginner interface { + // Begin self control transaction + Begin() (TxOrmer, error) + BeginWithCtx(ctx context.Context) (TxOrmer, error) + BeginWithOpts(opts *sql.TxOptions) (TxOrmer, error) + BeginWithCtxAndOpts(ctx context.Context, opts *sql.TxOptions) (TxOrmer, error) + + // DoTx closure control transaction + DoTx(task func(ctx context.Context, txOrm TxOrmer) error) error + DoTxWithCtx(ctx context.Context, task func(ctx context.Context, txOrm TxOrmer) error) error + DoTxWithOpts(opts *sql.TxOptions, task func(ctx context.Context, txOrm TxOrmer) error) error + DoTxWithCtxAndOpts(ctx context.Context, opts *sql.TxOptions, task func(ctx context.Context, txOrm TxOrmer) error) error +} + +type TxCommitter interface { + txEnder +} + +// transaction beginner +type txer interface { + Begin() (*sql.Tx, error) + BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) +} + +// transaction ending +type txEnder interface { + Commit() error + Rollback() error + + // RollbackUnlessCommit if the transaction has been committed, do nothing, or transaction will be rollback + // For example: + // ```go + // txOrm := orm.Begin() + // defer txOrm.RollbackUnlessCommit() + // err := txOrm.Insert() // do something + // if err != nil { + // return err + // } + // txOrm.Commit() + // ``` + RollbackUnlessCommit() error +} + +// DML Data Manipulation Language +type DML interface { + // Insert insert model data to database + // for example: + // user := new(User) + // id, err = Ormer.Insert(user) + // user must be a pointer and Insert will Set user's pk field + Insert(md interface{}) (int64, error) + InsertWithCtx(ctx context.Context, md interface{}) (int64, error) + // InsertOrUpdate mysql:InsertOrUpdate(model) or InsertOrUpdate(model,"colu=colu+value") + // if colu type is integer : can use(+-*/), string : convert(colu,"value") + // postgres: InsertOrUpdate(model,"conflictColumnName") or InsertOrUpdate(model,"conflictColumnName","colu=colu+value") + // if colu type is integer : can use(+-*/), string : colu || "value" + InsertOrUpdate(md interface{}, colConflitAndArgs ...string) (int64, error) + InsertOrUpdateWithCtx(ctx context.Context, md interface{}, colConflitAndArgs ...string) (int64, error) + // InsertMulti inserts some models to database + InsertMulti(bulk int, mds interface{}) (int64, error) + InsertMultiWithCtx(ctx context.Context, bulk int, mds interface{}) (int64, error) + // Update updates model to database. + // cols Set the Columns those want to update. + // find model by Id(pk) field and update Columns specified by Fields, if cols is null then update All Columns + // for example: + // user := User{Id: 2} + // user.Langs = append(user.Langs, "zh-CN", "en-US") + // user.Extra.Name = "beego" + // user.Extra.Data = "orm" + // num, err = Ormer.Update(&user, "Langs", "Extra") + Update(md interface{}, cols ...string) (int64, error) + UpdateWithCtx(ctx context.Context, md interface{}, cols ...string) (int64, error) + // Delete deletes model in database + Delete(md interface{}, cols ...string) (int64, error) + DeleteWithCtx(ctx context.Context, md interface{}, cols ...string) (int64, error) + + // Raw return a raw query seter for raw sql string. + // for example: + // ormer.Raw("UPDATE `user` SET `user_name` = ? WHERE `user_name` = ?", "slene", "testing").Exec() + // // update user testing's name to slene + Raw(query string, args ...interface{}) RawSeter + RawWithCtx(ctx context.Context, query string, args ...interface{}) RawSeter + ExecRaw(ctx context.Context, md interface{}, query string, args ...any) (sql.Result, error) +} + +// DQL Data Query Language +type DQL interface { + + // Read reads data to model + // for example: + // this will find User by Id field + // u = &User{Id: user.Id} + // err = Ormer.Read(u) + // this will find User by UserName field + // u = &User{UserName: "astaxie", Password: "pass"} + // err = Ormer.Read(u, "UserName") + Read(md interface{}, cols ...string) error + // ReadRaw reads data to model + ReadRaw(ctx context.Context, md interface{}, query string, args ...any) error + ReadWithCtx(ctx context.Context, md interface{}, cols ...string) error + + // ReadForUpdate Like Read(), but with "FOR UPDATE" clause, useful in transaction. + // Some databases are not support this feature. + ReadForUpdate(md interface{}, cols ...string) error + ReadForUpdateWithCtx(ctx context.Context, md interface{}, cols ...string) error + + // ReadOrCreate Try to read a row from the database, or insert one if it doesn't exist + ReadOrCreate(md interface{}, col1 string, cols ...string) (bool, int64, error) + ReadOrCreateWithCtx(ctx context.Context, md interface{}, col1 string, cols ...string) (bool, int64, error) + + // LoadRelated load related models to md model. + // args are limit, offset int and order string. + // + // example: + // Ormer.LoadRelated(post,"Tags") + // for _,tag := range post.Tags{...} + // hints.DefaultRelDepth useDefaultRelsDepth ; or depth 0 + // hints.RelDepth loadRelationDepth + // hints.Limit limit default limit 1000 + // hints.Offset int offset default offset 0 + // hints.OrderBy string order for example : "-Id" + // make sure the relation is defined in model struct tags. + LoadRelated(md interface{}, name string, args ...utils.KV) (int64, error) + LoadRelatedWithCtx(ctx context.Context, md interface{}, name string, args ...utils.KV) (int64, error) + + // QueryM2M create a models to models queryer + // for example: + // post := Post{Id: 4} + // m2m := Ormer.QueryM2M(&post, "Tags") + QueryM2M(md interface{}, name string) QueryM2Mer + // QueryM2MWithCtx NOTE: this method is deprecated, context parameter will not take effect. + // Use context.Context directly on methods with `WithCtx` suffix such as InsertWithCtx/UpdateWithCtx + QueryM2MWithCtx(ctx context.Context, md interface{}, name string) QueryM2Mer + + // QueryTable return a QuerySeter for table operations. + // table name can be string or struct. + // e.g. QueryTable("user"), QueryTable(&user{}) or QueryTable((*User)(nil)), + QueryTable(ptrStructOrTableName interface{}) QuerySeter + // QueryTableWithCtx NOTE: this method is deprecated, context parameter will not take effect. + // Use context.Context directly on methods with `WithCtx` suffix such as InsertWithCtx/UpdateWithCtx + QueryTableWithCtx(ctx context.Context, ptrStructOrTableName interface{}) QuerySeter + + DBStats() *sql.DBStats +} + +type DriverGetter interface { + Driver() Driver +} + +type ormer interface { + DQL + DML + DriverGetter +} + +// QueryExecutor wrapping for ormer +type QueryExecutor interface { + ormer +} + +type Ormer interface { + QueryExecutor + TxBeginner +} + +type TxOrmer interface { + QueryExecutor + TxCommitter +} + +// Inserter insert prepared statement +type Inserter interface { + Insert(interface{}) (int64, error) + InsertWithCtx(context.Context, interface{}) (int64, error) + Close() error +} + +// QuerySeter query seter +type QuerySeter interface { + // Filter add condition expression to QuerySeter. + // for example: + // filter by UserName == 'slene' + // qs.Filter("UserName", "slene") + // sql : left outer join profile on t0.id1==t1.id2 where t1.age == 28 + // Filter("profile__Age", 28) + // // time compare + // qs.Filter("created", time.Now()) + Filter(string, ...interface{}) QuerySeter + // FilterRaw add raw sql to querySeter. + // for example: + // qs.FilterRaw("user_id IN (SELECT id FROM profile WHERE age>=18)") + // //sql-> WHERE user_id IN (SELECT id FROM profile WHERE age>=18) + FilterRaw(string, string) QuerySeter + // Exclude add NOT condition to querySeter. + // have the same usage as Filter + Exclude(string, ...interface{}) QuerySeter + // SetCond Set condition to QuerySeter. + // sql's where condition + // cond := orm.NewCondition() + // cond1 := cond.And("profile__isnull", false).AndNot("status__in", 1).Or("profile__age__gt", 2000) + // //sql-> WHERE T0.`profile_id` IS NOT NULL AND NOT T0.`Status` IN (?) OR T1.`age` > 2000 + // num, err := qs.SetCond(cond1).Count() + SetCond(*Condition) QuerySeter + // GetCond Get condition from QuerySeter. + // sql's where condition + // cond := orm.NewCondition() + // cond = cond.And("profile__isnull", false).AndNot("status__in", 1) + // qs = qs.SetCond(cond) + // cond = qs.GetCond() + // cond := cond.Or("profile__age__gt", 2000) + // //sql-> WHERE T0.`profile_id` IS NOT NULL AND NOT T0.`Status` IN (?) OR T1.`age` > 2000 + // num, err := qs.SetCond(cond).Count() + GetCond() *Condition + // Limit add LIMIT value. + // args[0] means offset, e.g. LIMIT num,offset. + // if Limit <= 0 then Limit will be Set to default limit ,eg 1000 + // if QuerySeter doesn't call Limit, the sql's Limit will be Set to default limit, eg 1000 + // for example: + // qs.Limit(10, 2) + // // sql-> limit 10 offset 2 + Limit(limit interface{}, args ...interface{}) QuerySeter + // Offset add OFFSET value + // same as Limit function's args[0] + Offset(offset interface{}) QuerySeter + // GroupBy add GROUP BY expression + // for example: + // qs.GroupBy("id") + GroupBy(exprs ...string) QuerySeter + // OrderBy add ORDER expression. + // "column" means ASC, "-column" means DESC. + // for example: + // qs.OrderBy("-status") + OrderBy(exprs ...string) QuerySeter + // OrderClauses add ORDER expression by order clauses + // for example: + // OrderClauses( + // order_clause.Clause( + // order.Column("Id"), + // order.SortAscending(), + // ), + // order_clause.Clause( + // order.Column("status"), + // order.SortDescending(), + // ), + // ) + // OrderClauses(order_clause.Clause( + // order_clause.Column(`user__status`), + // order_clause.SortDescending(),//default None + // )) + // OrderClauses(order_clause.Clause( + // order_clause.Column(`random()`), + // order_clause.SortNone(),//default None + // order_clause.Raw(),//default false.if true, do not check field is valid or not + // )) + OrderClauses(orders ...*order_clause.Order) QuerySeter + // ForceIndex add FORCE INDEX expression. + // for example: + // qs.ForceIndex(`idx_name1`,`idx_name2`) + // ForceIndex, UseIndex , IgnoreIndex are mutually exclusive + ForceIndex(indexes ...string) QuerySeter + // UseIndex add USE INDEX expression. + // for example: + // qs.UseIndex(`idx_name1`,`idx_name2`) + // ForceIndex, UseIndex , IgnoreIndex are mutually exclusive + UseIndex(indexes ...string) QuerySeter + // IgnoreIndex add IGNORE INDEX expression. + // for example: + // qs.IgnoreIndex(`idx_name1`,`idx_name2`) + // ForceIndex, UseIndex , IgnoreIndex are mutually exclusive + IgnoreIndex(indexes ...string) QuerySeter + // RelatedSel Set relation model to query together. + // it will query relation models and assign to parent model. + // for example: + // // will load All related Fields use left join . + // qs.RelatedSel().One(&user) + // // will load related field only profile + // qs.RelatedSel("profile").One(&user) + // user.Profile.Age = 32 + RelatedSel(params ...interface{}) QuerySeter + // Distinct Set Distinct + // for example: + // o.QueryTable("policy").Filter("Groups__Group__Users__User", user). + // Distinct(). + // All(&permissions) + Distinct() QuerySeter + // ForUpdate Set FOR UPDATE to query. + // for example: + // o.QueryTable("user").Filter("uid", uid).ForUpdate().All(&users) + ForUpdate() QuerySeter + // Count returns QuerySeter execution result number + // for example: + // num, err = qs.Filter("profile__age__gt", 28).Count() + Count() (int64, error) + CountWithCtx(context.Context) (int64, error) + // Exist check result empty or not after QuerySeter executed + // the same as QuerySeter.Count > 0 + Exist() bool + ExistWithCtx(context.Context) bool + // Update execute update with parameters + // for example: + // num, err = qs.Filter("user_name", "slene").Update(Params{ + // "Nums": ColValue(Col_Minus, 50), + // }) // user slene's Nums will minus 50 + // num, err = qs.Filter("UserName", "slene").Update(Params{ + // "user_name": "slene2" + // }) // user slene's name will change to slene2 + Update(values Params) (int64, error) + UpdateWithCtx(ctx context.Context, values Params) (int64, error) + // Delete delete from table + // for example: + // num ,err = qs.Filter("user_name__in", "testing1", "testing2").Delete() + // //delete two user who's name is testing1 or testing2 + Delete() (int64, error) + DeleteWithCtx(context.Context) (int64, error) + // PrepareInsert return an insert queryer. + // it can be used in times. + // example: + // i,err := sq.PrepareInsert() + // num, err = i.Insert(&user1) // user table will add one record user1 at once + // num, err = i.Insert(&user2) // user table will add one record user2 at once + // err = i.Close() //don't forget call Close + PrepareInsert() (Inserter, error) + PrepareInsertWithCtx(context.Context) (Inserter, error) + // All query All data and map to containers. + // cols means the Columns when querying. + // for example: + // var users []*User + // qs.All(&users) // users[0],users[1],users[2] ... + All(container interface{}, cols ...string) (int64, error) + AllWithCtx(ctx context.Context, container interface{}, cols ...string) (int64, error) + // One query one row data and map to containers. + // cols means the Columns when querying. + // for example: + // var user User + // qs.One(&user) //user.UserName == "slene" + One(container interface{}, cols ...string) error + OneWithCtx(ctx context.Context, container interface{}, cols ...string) error + // Values query All data and map to []map[string]interface. + // expres means condition expression. + // it converts data to []map[column]value. + // for example: + // var maps []Params + // qs.Values(&maps) //maps[0]["UserName"]=="slene" + Values(results *[]Params, exprs ...string) (int64, error) + ValuesWithCtx(ctx context.Context, results *[]Params, exprs ...string) (int64, error) + // ValuesList query All data and map to [][]interface + // it converts data to [][column_index]value + // for example: + // var list []ParamsList + // qs.ValuesList(&list) // list[0][1] == "slene" + ValuesList(results *[]ParamsList, exprs ...string) (int64, error) + ValuesListWithCtx(ctx context.Context, results *[]ParamsList, exprs ...string) (int64, error) + // ValuesFlat query All data and map to []interface. + // it's designed for one column record Set, auto change to []value, not [][column]value. + // for example: + // var list ParamsList + // qs.ValuesFlat(&list, "UserName") // list[0] == "slene" + ValuesFlat(result *ParamsList, expr string) (int64, error) + ValuesFlatWithCtx(ctx context.Context, result *ParamsList, expr string) (int64, error) + // RowsToMap query All rows into map[string]interface with specify key and value column name. + // keyCol = "name", valueCol = "value" + // table data + // name | value + // total | 100 + // found | 200 + // to map[string]interface{}{ + // "total": 100, + // "found": 200, + // } + RowsToMap(result *Params, keyCol, valueCol string) (int64, error) + // RowsToStruct query All rows into struct with specify key and value column name. + // keyCol = "name", valueCol = "value" + // table data + // name | value + // total | 100 + // found | 200 + // to struct { + // Total int + // Found int + // } + RowsToStruct(ptrStruct interface{}, keyCol, valueCol string) (int64, error) + // Aggregate aggregate func. + // for example: + // type result struct { + // DeptName string + // Total int + // } + // var res []result + // o.QueryTable("dept_info").Aggregate("dept_name,sum(salary) as total").GroupBy("dept_name").All(&res) + Aggregate(s string) QuerySeter +} + +// QueryM2Mer model to model query struct +// All operations are on the m2m table only, will not affect the origin model table +type QueryM2Mer interface { + // Add adds models to origin models when creating queryM2M. + // example: + // m2m := orm.QueryM2M(post,"Tag") + // m2m.Add(&Tag1{},&Tag2{}) + // for _,tag := range post.Tags{}{ ... } + // param could also be any of the follow + // []*Tag{{Id:3,Name: "TestTag1"}, {Id:4,Name: "TestTag2"}} + // &Tag{Id:5,Name: "TestTag3"} + // []interface{}{&Tag{Id:6,Name: "TestTag4"}} + // insert one or more rows to m2m table + // make sure the relation is defined in post model struct tag. + Add(...interface{}) (int64, error) + AddWithCtx(context.Context, ...interface{}) (int64, error) + // Remove removes models following the origin model relationship + // only delete rows from m2m table + // for example: + // tag3 := &Tag{Id:5,Name: "TestTag3"} + // num, err = m2m.Remove(tag3) + Remove(...interface{}) (int64, error) + RemoveWithCtx(context.Context, ...interface{}) (int64, error) + // Exist checks model is existed in relationship of origin model + Exist(interface{}) bool + ExistWithCtx(context.Context, interface{}) bool + // Clear cleans All models in related of origin model + Clear() (int64, error) + ClearWithCtx(context.Context) (int64, error) + // Count counts All related models of origin model + Count() (int64, error) + CountWithCtx(context.Context) (int64, error) +} + +// RawPreparer raw query statement +type RawPreparer interface { + Exec(...interface{}) (sql.Result, error) + Close() error +} + +// RawSeter raw query seter +// create From Ormer.Raw +// for example: +// +// sql := fmt.Sprintf("SELECT %sid%s,%sname%s FROM %suser%s WHERE id = ?",Q,Q,Q,Q,Q,Q) +// rs := Ormer.Raw(sql, 1) +type RawSeter interface { + // Exec execute sql and Get result + Exec() (sql.Result, error) + // QueryRow query data and map to container + // for example: + // var name string + // var id int + // rs.QueryRow(&id,&name) // id==2 name=="slene" + QueryRow(containers ...interface{}) error + + // QueryRows query data rows and map to container + // var ids []int + // var names []int + // query = fmt.Sprintf("SELECT 'id','name' FROM %suser%s", Q, Q) + // num, err = dORM.Raw(query).QueryRows(&ids,&names) // ids=>{1,2},names=>{"nobody","slene"} + QueryRows(containers ...interface{}) (int64, error) + SetArgs(...interface{}) RawSeter + // Values query data to []map[string]interface + // see QuerySeter's Values + Values(container *[]Params, cols ...string) (int64, error) + // ValuesList query data to [][]interface + // see QuerySeter's ValuesList + ValuesList(container *[]ParamsList, cols ...string) (int64, error) + // ValuesFlat query data to []interface + // see QuerySeter's ValuesFlat + ValuesFlat(container *ParamsList, cols ...string) (int64, error) + // RowsToMap query All rows into map[string]interface with specify key and value column name. + // keyCol = "name", valueCol = "value" + // table data + // name | value + // total | 100 + // found | 200 + // to map[string]interface{}{ + // "total": 100, + // "found": 200, + // } + RowsToMap(result *Params, keyCol, valueCol string) (int64, error) + // RowsToStruct query All rows into struct with specify key and value column name. + // keyCol = "name", valueCol = "value" + // table data + // name | value + // total | 100 + // found | 200 + // to struct { + // Total int + // Found int + // } + RowsToStruct(ptrStruct interface{}, keyCol, valueCol string) (int64, error) + + // Prepare return prepared raw statement for used in times. + // for example: + // pre, err := dORM.Raw("INSERT INTO tag (name) VALUES (?)").Prepare() + // r, err := pre.Exec("name1") // INSERT INTO tag (name) VALUES (`name1`) + Prepare() (RawPreparer, error) +} + +// stmtQuerier statement querier +type stmtQuerier interface { + Close() error + Exec(args ...interface{}) (sql.Result, error) + ExecContext(ctx context.Context, args ...interface{}) (sql.Result, error) + Query(args ...interface{}) (*sql.Rows, error) + QueryContext(ctx context.Context, args ...interface{}) (*sql.Rows, error) + QueryRow(args ...interface{}) *sql.Row + QueryRowContext(ctx context.Context, args ...interface{}) *sql.Row +} + +// db querier +type dbQuerier interface { + Prepare(query string) (*sql.Stmt, error) + PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) + Exec(query string, args ...interface{}) (sql.Result, error) + ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) + Query(query string, args ...interface{}) (*sql.Rows, error) + QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) + QueryRow(query string, args ...interface{}) *sql.Row + QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row +} + +// type DB interface { +// Begin() (*sql.Tx, error) +// Prepare(query string) (stmtQuerier, error) +// Exec(query string, args ...interface{}) (sql.Result, error) +// Query(query string, args ...interface{}) (*sql.Rows, error) +// QueryRow(query string, args ...interface{}) *sql.Row +// } + +// base database struct +type dbBaser interface { + Read(context.Context, dbQuerier, *models.ModelInfo, reflect.Value, *time.Location, []string, bool) error + ReadRaw(ctx context.Context, q dbQuerier, mi *models.ModelInfo, ind reflect.Value, tz *time.Location, query string, args ...any) error + ReadBatch(context.Context, dbQuerier, querySet, *models.ModelInfo, *Condition, interface{}, *time.Location, []string) (int64, error) + Count(context.Context, dbQuerier, querySet, *models.ModelInfo, *Condition, *time.Location) (int64, error) + ReadValues(context.Context, dbQuerier, querySet, *models.ModelInfo, *Condition, []string, interface{}, *time.Location) (int64, error) + + ExecRaw(ctx context.Context, q dbQuerier, query string, args ...any) (sql.Result, error) + Insert(context.Context, dbQuerier, *models.ModelInfo, reflect.Value, *time.Location) (int64, error) + InsertOrUpdate(context.Context, dbQuerier, *models.ModelInfo, reflect.Value, *DB, ...string) (int64, error) + InsertMulti(context.Context, dbQuerier, *models.ModelInfo, reflect.Value, int, *time.Location) (int64, error) + InsertValue(context.Context, dbQuerier, *models.ModelInfo, bool, []string, []interface{}) (int64, error) + InsertStmt(context.Context, stmtQuerier, *models.ModelInfo, reflect.Value, *time.Location) (int64, error) + + Update(context.Context, dbQuerier, *models.ModelInfo, reflect.Value, *time.Location, []string) (int64, error) + UpdateBatch(context.Context, dbQuerier, *querySet, *models.ModelInfo, *Condition, Params, *time.Location) (int64, error) + + Delete(context.Context, dbQuerier, *models.ModelInfo, reflect.Value, *time.Location, []string) (int64, error) + DeleteBatch(context.Context, dbQuerier, *querySet, *models.ModelInfo, *Condition, *time.Location) (int64, error) + + SupportUpdateJoin() bool + OperatorSQL(string) string + GenerateOperatorSQL(*models.ModelInfo, *models.FieldInfo, string, []interface{}, *time.Location) (string, []interface{}) + GenerateOperatorLeftCol(*models.FieldInfo, string, *string) + PrepareInsert(context.Context, dbQuerier, *models.ModelInfo) (stmtQuerier, string, error) + MaxLimit() uint64 + TableQuote() string + ReplaceMarks(*string) + HasReturningID(*models.ModelInfo, *string) bool + TimeFromDB(*time.Time, *time.Location) + TimeToDB(*time.Time, *time.Location) + DbTypes() map[string]string + GetTables(dbQuerier) (map[string]bool, error) + GetColumns(context.Context, dbQuerier, string) (map[string][3]string, error) + ShowTablesQuery() string + ShowColumnsQuery(string) string + IndexExists(context.Context, dbQuerier, string, string) bool + collectFieldValue(*models.ModelInfo, *models.FieldInfo, reflect.Value, bool, *time.Location) (interface{}, error) + setval(context.Context, dbQuerier, *models.ModelInfo, []string) error + + GenerateSpecifyIndex(tableName string, useIndex int, indexes []string) string +} diff --git a/client/orm/utils.go b/client/orm/utils.go new file mode 100644 index 0000000000..40a818d6d6 --- /dev/null +++ b/client/orm/utils.go @@ -0,0 +1,29 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package orm + +import ( + "github.com/beego/beego/v2/client/orm/internal/models" + "github.com/beego/beego/v2/client/orm/internal/utils" +) + +type StrTo = utils.StrTo + +func SetNameStrategy(s string) { + if models.SnakeAcronymNameStrategy != s { + models.NameStrategy = models.DefaultNameStrategy + } + models.NameStrategy = s +} diff --git a/config.go b/config.go deleted file mode 100644 index 7969dcea51..0000000000 --- a/config.go +++ /dev/null @@ -1,510 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package beego - -import ( - "fmt" - "os" - "path/filepath" - "reflect" - "runtime" - "strings" - - "github.com/astaxie/beego/config" - "github.com/astaxie/beego/context" - "github.com/astaxie/beego/logs" - "github.com/astaxie/beego/session" - "github.com/astaxie/beego/utils" -) - -// Config is the main struct for BConfig -type Config struct { - AppName string //Application name - RunMode string //Running Mode: dev | prod - RouterCaseSensitive bool - ServerName string - RecoverPanic bool - RecoverFunc func(*context.Context) - CopyRequestBody bool - EnableGzip bool - MaxMemory int64 - EnableErrorsShow bool - EnableErrorsRender bool - Listen Listen - WebConfig WebConfig - Log LogConfig -} - -// Listen holds for http and https related config -type Listen struct { - Graceful bool // Graceful means use graceful module to start the server - ServerTimeOut int64 - ListenTCP4 bool - EnableHTTP bool - HTTPAddr string - HTTPPort int - AutoTLS bool - Domains []string - TLSCacheDir string - EnableHTTPS bool - EnableMutualHTTPS bool - HTTPSAddr string - HTTPSPort int - HTTPSCertFile string - HTTPSKeyFile string - TrustCaFile string - EnableAdmin bool - AdminAddr string - AdminPort int - EnableFcgi bool - EnableStdIo bool // EnableStdIo works with EnableFcgi Use FCGI via standard I/O -} - -// WebConfig holds web related config -type WebConfig struct { - AutoRender bool - EnableDocs bool - FlashName string - FlashSeparator string - DirectoryIndex bool - StaticDir map[string]string - StaticExtensionsToGzip []string - TemplateLeft string - TemplateRight string - ViewsPath string - EnableXSRF bool - XSRFKey string - XSRFExpire int - Session SessionConfig -} - -// SessionConfig holds session related config -type SessionConfig struct { - SessionOn bool - SessionProvider string - SessionName string - SessionGCMaxLifetime int64 - SessionProviderConfig string - SessionCookieLifeTime int - SessionAutoSetCookie bool - SessionDomain string - SessionDisableHTTPOnly bool // used to allow for cross domain cookies/javascript cookies. - SessionEnableSidInHTTPHeader bool // enable store/get the sessionId into/from http headers - SessionNameInHTTPHeader string - SessionEnableSidInURLQuery bool // enable get the sessionId from Url Query params -} - -// LogConfig holds Log related config -type LogConfig struct { - AccessLogs bool - EnableStaticLogs bool //log static files requests default: false - AccessLogsFormat string //access log format: JSON_FORMAT, APACHE_FORMAT or empty string - FileLineNum bool - Outputs map[string]string // Store Adaptor : config -} - -var ( - // BConfig is the default config for Application - BConfig *Config - // AppConfig is the instance of Config, store the config information from file - AppConfig *beegoAppConfig - // AppPath is the absolute path to the app - AppPath string - // GlobalSessions is the instance for the session manager - GlobalSessions *session.Manager - - // appConfigPath is the path to the config files - appConfigPath string - // appConfigProvider is the provider for the config, default is ini - appConfigProvider = "ini" -) - -func init() { - BConfig = newBConfig() - var err error - if AppPath, err = filepath.Abs(filepath.Dir(os.Args[0])); err != nil { - panic(err) - } - workPath, err := os.Getwd() - if err != nil { - panic(err) - } - var filename = "app.conf" - if os.Getenv("BEEGO_RUNMODE") != "" { - filename = os.Getenv("BEEGO_RUNMODE") + ".app.conf" - } - appConfigPath = filepath.Join(workPath, "conf", filename) - if !utils.FileExists(appConfigPath) { - appConfigPath = filepath.Join(AppPath, "conf", filename) - if !utils.FileExists(appConfigPath) { - AppConfig = &beegoAppConfig{innerConfig: config.NewFakeConfig()} - return - } - } - if err = parseConfig(appConfigPath); err != nil { - panic(err) - } -} - -func recoverPanic(ctx *context.Context) { - if err := recover(); err != nil { - if err == ErrAbort { - return - } - if !BConfig.RecoverPanic { - panic(err) - } - if BConfig.EnableErrorsShow { - if _, ok := ErrorMaps[fmt.Sprint(err)]; ok { - exception(fmt.Sprint(err), ctx) - return - } - } - var stack string - logs.Critical("the request url is ", ctx.Input.URL()) - logs.Critical("Handler crashed with error", err) - for i := 1; ; i++ { - _, file, line, ok := runtime.Caller(i) - if !ok { - break - } - logs.Critical(fmt.Sprintf("%s:%d", file, line)) - stack = stack + fmt.Sprintln(fmt.Sprintf("%s:%d", file, line)) - } - if BConfig.RunMode == DEV && BConfig.EnableErrorsRender { - showErr(err, ctx, stack) - } - if ctx.Output.Status != 0 { - ctx.ResponseWriter.WriteHeader(ctx.Output.Status) - } else { - ctx.ResponseWriter.WriteHeader(500) - } - } -} - -func newBConfig() *Config { - return &Config{ - AppName: "beego", - RunMode: PROD, - RouterCaseSensitive: true, - ServerName: "beegoServer:" + VERSION, - RecoverPanic: true, - RecoverFunc: recoverPanic, - CopyRequestBody: false, - EnableGzip: false, - MaxMemory: 1 << 26, //64MB - EnableErrorsShow: true, - EnableErrorsRender: true, - Listen: Listen{ - Graceful: false, - ServerTimeOut: 0, - ListenTCP4: false, - EnableHTTP: true, - AutoTLS: false, - Domains: []string{}, - TLSCacheDir: ".", - HTTPAddr: "", - HTTPPort: 8080, - EnableHTTPS: false, - HTTPSAddr: "", - HTTPSPort: 10443, - HTTPSCertFile: "", - HTTPSKeyFile: "", - EnableAdmin: false, - AdminAddr: "", - AdminPort: 8088, - EnableFcgi: false, - EnableStdIo: false, - }, - WebConfig: WebConfig{ - AutoRender: true, - EnableDocs: false, - FlashName: "BEEGO_FLASH", - FlashSeparator: "BEEGOFLASH", - DirectoryIndex: false, - StaticDir: map[string]string{"/static": "static"}, - StaticExtensionsToGzip: []string{".css", ".js"}, - TemplateLeft: "{{", - TemplateRight: "}}", - ViewsPath: "views", - EnableXSRF: false, - XSRFKey: "beegoxsrf", - XSRFExpire: 0, - Session: SessionConfig{ - SessionOn: false, - SessionProvider: "memory", - SessionName: "beegosessionID", - SessionGCMaxLifetime: 3600, - SessionProviderConfig: "", - SessionDisableHTTPOnly: false, - SessionCookieLifeTime: 0, //set cookie default is the browser life - SessionAutoSetCookie: true, - SessionDomain: "", - SessionEnableSidInHTTPHeader: false, // enable store/get the sessionId into/from http headers - SessionNameInHTTPHeader: "Beegosessionid", - SessionEnableSidInURLQuery: false, // enable get the sessionId from Url Query params - }, - }, - Log: LogConfig{ - AccessLogs: false, - EnableStaticLogs: false, - AccessLogsFormat: "APACHE_FORMAT", - FileLineNum: true, - Outputs: map[string]string{"console": ""}, - }, - } -} - -// now only support ini, next will support json. -func parseConfig(appConfigPath string) (err error) { - AppConfig, err = newAppConfig(appConfigProvider, appConfigPath) - if err != nil { - return err - } - return assignConfig(AppConfig) -} - -func assignConfig(ac config.Configer) error { - for _, i := range []interface{}{BConfig, &BConfig.Listen, &BConfig.WebConfig, &BConfig.Log, &BConfig.WebConfig.Session} { - assignSingleConfig(i, ac) - } - // set the run mode first - if envRunMode := os.Getenv("BEEGO_RUNMODE"); envRunMode != "" { - BConfig.RunMode = envRunMode - } else if runMode := ac.String("RunMode"); runMode != "" { - BConfig.RunMode = runMode - } - - if sd := ac.String("StaticDir"); sd != "" { - BConfig.WebConfig.StaticDir = map[string]string{} - sds := strings.Fields(sd) - for _, v := range sds { - if url2fsmap := strings.SplitN(v, ":", 2); len(url2fsmap) == 2 { - BConfig.WebConfig.StaticDir["/"+strings.Trim(url2fsmap[0], "/")] = url2fsmap[1] - } else { - BConfig.WebConfig.StaticDir["/"+strings.Trim(url2fsmap[0], "/")] = url2fsmap[0] - } - } - } - - if sgz := ac.String("StaticExtensionsToGzip"); sgz != "" { - extensions := strings.Split(sgz, ",") - fileExts := []string{} - for _, ext := range extensions { - ext = strings.TrimSpace(ext) - if ext == "" { - continue - } - if !strings.HasPrefix(ext, ".") { - ext = "." + ext - } - fileExts = append(fileExts, ext) - } - if len(fileExts) > 0 { - BConfig.WebConfig.StaticExtensionsToGzip = fileExts - } - } - - if lo := ac.String("LogOutputs"); lo != "" { - // if lo is not nil or empty - // means user has set his own LogOutputs - // clear the default setting to BConfig.Log.Outputs - BConfig.Log.Outputs = make(map[string]string) - los := strings.Split(lo, ";") - for _, v := range los { - if logType2Config := strings.SplitN(v, ",", 2); len(logType2Config) == 2 { - BConfig.Log.Outputs[logType2Config[0]] = logType2Config[1] - } else { - continue - } - } - } - - //init log - logs.Reset() - for adaptor, config := range BConfig.Log.Outputs { - err := logs.SetLogger(adaptor, config) - if err != nil { - fmt.Fprintln(os.Stderr, fmt.Sprintf("%s with the config %q got err:%s", adaptor, config, err.Error())) - } - } - logs.SetLogFuncCall(BConfig.Log.FileLineNum) - - return nil -} - -func assignSingleConfig(p interface{}, ac config.Configer) { - pt := reflect.TypeOf(p) - if pt.Kind() != reflect.Ptr { - return - } - pt = pt.Elem() - if pt.Kind() != reflect.Struct { - return - } - pv := reflect.ValueOf(p).Elem() - - for i := 0; i < pt.NumField(); i++ { - pf := pv.Field(i) - if !pf.CanSet() { - continue - } - name := pt.Field(i).Name - switch pf.Kind() { - case reflect.String: - pf.SetString(ac.DefaultString(name, pf.String())) - case reflect.Int, reflect.Int64: - pf.SetInt(ac.DefaultInt64(name, pf.Int())) - case reflect.Bool: - pf.SetBool(ac.DefaultBool(name, pf.Bool())) - case reflect.Struct: - default: - //do nothing here - } - } - -} - -// LoadAppConfig allow developer to apply a config file -func LoadAppConfig(adapterName, configPath string) error { - absConfigPath, err := filepath.Abs(configPath) - if err != nil { - return err - } - - if !utils.FileExists(absConfigPath) { - return fmt.Errorf("the target config file: %s don't exist", configPath) - } - - appConfigPath = absConfigPath - appConfigProvider = adapterName - - return parseConfig(appConfigPath) -} - -type beegoAppConfig struct { - innerConfig config.Configer -} - -func newAppConfig(appConfigProvider, appConfigPath string) (*beegoAppConfig, error) { - ac, err := config.NewConfig(appConfigProvider, appConfigPath) - if err != nil { - return nil, err - } - return &beegoAppConfig{ac}, nil -} - -func (b *beegoAppConfig) Set(key, val string) error { - if err := b.innerConfig.Set(BConfig.RunMode+"::"+key, val); err != nil { - return err - } - return b.innerConfig.Set(key, val) -} - -func (b *beegoAppConfig) String(key string) string { - if v := b.innerConfig.String(BConfig.RunMode + "::" + key); v != "" { - return v - } - return b.innerConfig.String(key) -} - -func (b *beegoAppConfig) Strings(key string) []string { - if v := b.innerConfig.Strings(BConfig.RunMode + "::" + key); len(v) > 0 { - return v - } - return b.innerConfig.Strings(key) -} - -func (b *beegoAppConfig) Int(key string) (int, error) { - if v, err := b.innerConfig.Int(BConfig.RunMode + "::" + key); err == nil { - return v, nil - } - return b.innerConfig.Int(key) -} - -func (b *beegoAppConfig) Int64(key string) (int64, error) { - if v, err := b.innerConfig.Int64(BConfig.RunMode + "::" + key); err == nil { - return v, nil - } - return b.innerConfig.Int64(key) -} - -func (b *beegoAppConfig) Bool(key string) (bool, error) { - if v, err := b.innerConfig.Bool(BConfig.RunMode + "::" + key); err == nil { - return v, nil - } - return b.innerConfig.Bool(key) -} - -func (b *beegoAppConfig) Float(key string) (float64, error) { - if v, err := b.innerConfig.Float(BConfig.RunMode + "::" + key); err == nil { - return v, nil - } - return b.innerConfig.Float(key) -} - -func (b *beegoAppConfig) DefaultString(key string, defaultVal string) string { - if v := b.String(key); v != "" { - return v - } - return defaultVal -} - -func (b *beegoAppConfig) DefaultStrings(key string, defaultVal []string) []string { - if v := b.Strings(key); len(v) != 0 { - return v - } - return defaultVal -} - -func (b *beegoAppConfig) DefaultInt(key string, defaultVal int) int { - if v, err := b.Int(key); err == nil { - return v - } - return defaultVal -} - -func (b *beegoAppConfig) DefaultInt64(key string, defaultVal int64) int64 { - if v, err := b.Int64(key); err == nil { - return v - } - return defaultVal -} - -func (b *beegoAppConfig) DefaultBool(key string, defaultVal bool) bool { - if v, err := b.Bool(key); err == nil { - return v - } - return defaultVal -} - -func (b *beegoAppConfig) DefaultFloat(key string, defaultVal float64) float64 { - if v, err := b.Float(key); err == nil { - return v - } - return defaultVal -} - -func (b *beegoAppConfig) DIY(key string) (interface{}, error) { - return b.innerConfig.DIY(key) -} - -func (b *beegoAppConfig) GetSection(section string) (map[string]string, error) { - return b.innerConfig.GetSection(section) -} - -func (b *beegoAppConfig) SaveConfigFile(filename string) error { - return b.innerConfig.SaveConfigFile(filename) -} diff --git a/config/env/env.go b/config/env/env.go deleted file mode 100644 index 34f094febf..0000000000 --- a/config/env/env.go +++ /dev/null @@ -1,87 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// Copyright 2017 Faissal Elamraoui. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package env is used to parse environment. -package env - -import ( - "fmt" - "os" - "strings" - - "github.com/astaxie/beego/utils" -) - -var env *utils.BeeMap - -func init() { - env = utils.NewBeeMap() - for _, e := range os.Environ() { - splits := strings.Split(e, "=") - env.Set(splits[0], os.Getenv(splits[0])) - } -} - -// Get returns a value by key. -// If the key does not exist, the default value will be returned. -func Get(key string, defVal string) string { - if val := env.Get(key); val != nil { - return val.(string) - } - return defVal -} - -// MustGet returns a value by key. -// If the key does not exist, it will return an error. -func MustGet(key string) (string, error) { - if val := env.Get(key); val != nil { - return val.(string), nil - } - return "", fmt.Errorf("no env variable with %s", key) -} - -// Set sets a value in the ENV copy. -// This does not affect the child process environment. -func Set(key string, value string) { - env.Set(key, value) -} - -// MustSet sets a value in the ENV copy and the child process environment. -// It returns an error in case the set operation failed. -func MustSet(key string, value string) error { - err := os.Setenv(key, value) - if err != nil { - return err - } - env.Set(key, value) - return nil -} - -// GetAll returns all keys/values in the current child process environment. -func GetAll() map[string]string { - items := env.Items() - envs := make(map[string]string, env.Count()) - - for key, val := range items { - switch key := key.(type) { - case string: - switch val := val.(type) { - case string: - envs[key] = val - } - } - } - return envs -} diff --git a/config/yaml/yaml_test.go b/config/yaml/yaml_test.go deleted file mode 100644 index 49cc1d1e7f..0000000000 --- a/config/yaml/yaml_test.go +++ /dev/null @@ -1,115 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package yaml - -import ( - "fmt" - "os" - "testing" - - "github.com/astaxie/beego/config" -) - -func TestYaml(t *testing.T) { - - var ( - yamlcontext = ` -"appname": beeapi -"httpport": 8080 -"mysqlport": 3600 -"PI": 3.1415976 -"runmode": dev -"autorender": false -"copyrequestbody": true -"PATH": GOPATH -"path1": ${GOPATH} -"path2": ${GOPATH||/home/go} -"empty": "" -` - - keyValue = map[string]interface{}{ - "appname": "beeapi", - "httpport": 8080, - "mysqlport": int64(3600), - "PI": 3.1415976, - "runmode": "dev", - "autorender": false, - "copyrequestbody": true, - "PATH": "GOPATH", - "path1": os.Getenv("GOPATH"), - "path2": os.Getenv("GOPATH"), - "error": "", - "emptystrings": []string{}, - } - ) - f, err := os.Create("testyaml.conf") - if err != nil { - t.Fatal(err) - } - _, err = f.WriteString(yamlcontext) - if err != nil { - f.Close() - t.Fatal(err) - } - f.Close() - defer os.Remove("testyaml.conf") - yamlconf, err := config.NewConfig("yaml", "testyaml.conf") - if err != nil { - t.Fatal(err) - } - - if yamlconf.String("appname") != "beeapi" { - t.Fatal("appname not equal to beeapi") - } - - for k, v := range keyValue { - - var ( - value interface{} - err error - ) - - switch v.(type) { - case int: - value, err = yamlconf.Int(k) - case int64: - value, err = yamlconf.Int64(k) - case float64: - value, err = yamlconf.Float(k) - case bool: - value, err = yamlconf.Bool(k) - case []string: - value = yamlconf.Strings(k) - case string: - value = yamlconf.String(k) - default: - value, err = yamlconf.DIY(k) - } - if err != nil { - t.Errorf("get key %q value fatal,%v err %s", k, v, err) - } else if fmt.Sprintf("%v", v) != fmt.Sprintf("%v", value) { - t.Errorf("get key %q value, want %v got %v .", k, v, value) - } - - } - - if err = yamlconf.Set("name", "astaxie"); err != nil { - t.Fatal(err) - } - if yamlconf.String("name") != "astaxie" { - t.Fatal("get name error") - } - -} diff --git a/context/context.go b/context/context.go deleted file mode 100644 index bbd58299f6..0000000000 --- a/context/context.go +++ /dev/null @@ -1,263 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package context provide the context utils -// Usage: -// -// import "github.com/astaxie/beego/context" -// -// ctx := context.Context{Request:req,ResponseWriter:rw} -// -// more docs http://beego.me/docs/module/context.md -package context - -import ( - "bufio" - "crypto/hmac" - "crypto/sha1" - "encoding/base64" - "errors" - "fmt" - "net" - "net/http" - "strconv" - "strings" - "time" - - "github.com/astaxie/beego/utils" -) - -//commonly used mime-types -const ( - ApplicationJSON = "application/json" - ApplicationXML = "application/xml" - ApplicationYAML = "application/x-yaml" - TextXML = "text/xml" -) - -// NewContext return the Context with Input and Output -func NewContext() *Context { - return &Context{ - Input: NewInput(), - Output: NewOutput(), - } -} - -// Context Http request context struct including BeegoInput, BeegoOutput, http.Request and http.ResponseWriter. -// BeegoInput and BeegoOutput provides some api to operate request and response more easily. -type Context struct { - Input *BeegoInput - Output *BeegoOutput - Request *http.Request - ResponseWriter *Response - _xsrfToken string -} - -// Reset init Context, BeegoInput and BeegoOutput -func (ctx *Context) Reset(rw http.ResponseWriter, r *http.Request) { - ctx.Request = r - if ctx.ResponseWriter == nil { - ctx.ResponseWriter = &Response{} - } - ctx.ResponseWriter.reset(rw) - ctx.Input.Reset(ctx) - ctx.Output.Reset(ctx) - ctx._xsrfToken = "" -} - -// Redirect does redirection to localurl with http header status code. -func (ctx *Context) Redirect(status int, localurl string) { - http.Redirect(ctx.ResponseWriter, ctx.Request, localurl, status) -} - -// Abort stops this request. -// if beego.ErrorMaps exists, panic body. -func (ctx *Context) Abort(status int, body string) { - ctx.Output.SetStatus(status) - panic(body) -} - -// WriteString Write string to response body. -// it sends response body. -func (ctx *Context) WriteString(content string) { - ctx.ResponseWriter.Write([]byte(content)) -} - -// GetCookie Get cookie from request by a given key. -// It's alias of BeegoInput.Cookie. -func (ctx *Context) GetCookie(key string) string { - return ctx.Input.Cookie(key) -} - -// SetCookie Set cookie for response. -// It's alias of BeegoOutput.Cookie. -func (ctx *Context) SetCookie(name string, value string, others ...interface{}) { - ctx.Output.Cookie(name, value, others...) -} - -// GetSecureCookie Get secure cookie from request by a given key. -func (ctx *Context) GetSecureCookie(Secret, key string) (string, bool) { - val := ctx.Input.Cookie(key) - if val == "" { - return "", false - } - - parts := strings.SplitN(val, "|", 3) - - if len(parts) != 3 { - return "", false - } - - vs := parts[0] - timestamp := parts[1] - sig := parts[2] - - h := hmac.New(sha1.New, []byte(Secret)) - fmt.Fprintf(h, "%s%s", vs, timestamp) - - if fmt.Sprintf("%02x", h.Sum(nil)) != sig { - return "", false - } - res, _ := base64.URLEncoding.DecodeString(vs) - return string(res), true -} - -// SetSecureCookie Set Secure cookie for response. -func (ctx *Context) SetSecureCookie(Secret, name, value string, others ...interface{}) { - vs := base64.URLEncoding.EncodeToString([]byte(value)) - timestamp := strconv.FormatInt(time.Now().UnixNano(), 10) - h := hmac.New(sha1.New, []byte(Secret)) - fmt.Fprintf(h, "%s%s", vs, timestamp) - sig := fmt.Sprintf("%02x", h.Sum(nil)) - cookie := strings.Join([]string{vs, timestamp, sig}, "|") - ctx.Output.Cookie(name, cookie, others...) -} - -// XSRFToken creates a xsrf token string and returns. -func (ctx *Context) XSRFToken(key string, expire int64) string { - if ctx._xsrfToken == "" { - token, ok := ctx.GetSecureCookie(key, "_xsrf") - if !ok { - token = string(utils.RandomCreateBytes(32)) - ctx.SetSecureCookie(key, "_xsrf", token, expire) - } - ctx._xsrfToken = token - } - return ctx._xsrfToken -} - -// CheckXSRFCookie checks xsrf token in this request is valid or not. -// the token can provided in request header "X-Xsrftoken" and "X-CsrfToken" -// or in form field value named as "_xsrf". -func (ctx *Context) CheckXSRFCookie() bool { - token := ctx.Input.Query("_xsrf") - if token == "" { - token = ctx.Request.Header.Get("X-Xsrftoken") - } - if token == "" { - token = ctx.Request.Header.Get("X-Csrftoken") - } - if token == "" { - ctx.Abort(403, "'_xsrf' argument missing from POST") - return false - } - if ctx._xsrfToken != token { - ctx.Abort(403, "XSRF cookie does not match POST argument") - return false - } - return true -} - -// RenderMethodResult renders the return value of a controller method to the output -func (ctx *Context) RenderMethodResult(result interface{}) { - if result != nil { - renderer, ok := result.(Renderer) - if !ok { - err, ok := result.(error) - if ok { - renderer = errorRenderer(err) - } else { - renderer = jsonRenderer(result) - } - } - renderer.Render(ctx) - } -} - -//Response is a wrapper for the http.ResponseWriter -//started set to true if response was written to then don't execute other handler -type Response struct { - http.ResponseWriter - Started bool - Status int - Elapsed time.Duration -} - -func (r *Response) reset(rw http.ResponseWriter) { - r.ResponseWriter = rw - r.Status = 0 - r.Started = false -} - -// Write writes the data to the connection as part of an HTTP reply, -// and sets `started` to true. -// started means the response has sent out. -func (r *Response) Write(p []byte) (int, error) { - r.Started = true - return r.ResponseWriter.Write(p) -} - -// WriteHeader sends an HTTP response header with status code, -// and sets `started` to true. -func (r *Response) WriteHeader(code int) { - if r.Status > 0 { - //prevent multiple response.WriteHeader calls - return - } - r.Status = code - r.Started = true - r.ResponseWriter.WriteHeader(code) -} - -// Hijack hijacker for http -func (r *Response) Hijack() (net.Conn, *bufio.ReadWriter, error) { - hj, ok := r.ResponseWriter.(http.Hijacker) - if !ok { - return nil, nil, errors.New("webserver doesn't support hijacking") - } - return hj.Hijack() -} - -// Flush http.Flusher -func (r *Response) Flush() { - if f, ok := r.ResponseWriter.(http.Flusher); ok { - f.Flush() - } -} - -// CloseNotify http.CloseNotifier -func (r *Response) CloseNotify() <-chan bool { - if cn, ok := r.ResponseWriter.(http.CloseNotifier); ok { - return cn.CloseNotify() - } - return nil -} - -// Pusher http.Pusher -func (r *Response) Pusher() (pusher http.Pusher) { - if pusher, ok := r.ResponseWriter.(http.Pusher); ok { - return pusher - } - return nil -} diff --git a/context/context_test.go b/context/context_test.go deleted file mode 100644 index 7c0535e0ae..0000000000 --- a/context/context_test.go +++ /dev/null @@ -1,47 +0,0 @@ -// Copyright 2016 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package context - -import ( - "net/http" - "net/http/httptest" - "testing" -) - -func TestXsrfReset_01(t *testing.T) { - r := &http.Request{} - c := NewContext() - c.Request = r - c.ResponseWriter = &Response{} - c.ResponseWriter.reset(httptest.NewRecorder()) - c.Output.Reset(c) - c.Input.Reset(c) - c.XSRFToken("key", 16) - if c._xsrfToken == "" { - t.FailNow() - } - token := c._xsrfToken - c.Reset(&Response{ResponseWriter: httptest.NewRecorder()}, r) - if c._xsrfToken != "" { - t.FailNow() - } - c.XSRFToken("key", 16) - if c._xsrfToken == "" { - t.FailNow() - } - if token == c._xsrfToken { - t.FailNow() - } -} diff --git a/controller_test.go b/controller_test.go deleted file mode 100644 index 1e53416d7c..0000000000 --- a/controller_test.go +++ /dev/null @@ -1,181 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package beego - -import ( - "math" - "strconv" - "testing" - - "github.com/astaxie/beego/context" - "os" - "path/filepath" -) - -func TestGetInt(t *testing.T) { - i := context.NewInput() - i.SetParam("age", "40") - ctx := &context.Context{Input: i} - ctrlr := Controller{Ctx: ctx} - val, _ := ctrlr.GetInt("age") - if val != 40 { - t.Errorf("TestGetInt expect 40,get %T,%v", val, val) - } -} - -func TestGetInt8(t *testing.T) { - i := context.NewInput() - i.SetParam("age", "40") - ctx := &context.Context{Input: i} - ctrlr := Controller{Ctx: ctx} - val, _ := ctrlr.GetInt8("age") - if val != 40 { - t.Errorf("TestGetInt8 expect 40,get %T,%v", val, val) - } - //Output: int8 -} - -func TestGetInt16(t *testing.T) { - i := context.NewInput() - i.SetParam("age", "40") - ctx := &context.Context{Input: i} - ctrlr := Controller{Ctx: ctx} - val, _ := ctrlr.GetInt16("age") - if val != 40 { - t.Errorf("TestGetInt16 expect 40,get %T,%v", val, val) - } -} - -func TestGetInt32(t *testing.T) { - i := context.NewInput() - i.SetParam("age", "40") - ctx := &context.Context{Input: i} - ctrlr := Controller{Ctx: ctx} - val, _ := ctrlr.GetInt32("age") - if val != 40 { - t.Errorf("TestGetInt32 expect 40,get %T,%v", val, val) - } -} - -func TestGetInt64(t *testing.T) { - i := context.NewInput() - i.SetParam("age", "40") - ctx := &context.Context{Input: i} - ctrlr := Controller{Ctx: ctx} - val, _ := ctrlr.GetInt64("age") - if val != 40 { - t.Errorf("TestGeetInt64 expect 40,get %T,%v", val, val) - } -} - -func TestGetUint8(t *testing.T) { - i := context.NewInput() - i.SetParam("age", strconv.FormatUint(math.MaxUint8, 10)) - ctx := &context.Context{Input: i} - ctrlr := Controller{Ctx: ctx} - val, _ := ctrlr.GetUint8("age") - if val != math.MaxUint8 { - t.Errorf("TestGetUint8 expect %v,get %T,%v", math.MaxUint8, val, val) - } -} - -func TestGetUint16(t *testing.T) { - i := context.NewInput() - i.SetParam("age", strconv.FormatUint(math.MaxUint16, 10)) - ctx := &context.Context{Input: i} - ctrlr := Controller{Ctx: ctx} - val, _ := ctrlr.GetUint16("age") - if val != math.MaxUint16 { - t.Errorf("TestGetUint16 expect %v,get %T,%v", math.MaxUint16, val, val) - } -} - -func TestGetUint32(t *testing.T) { - i := context.NewInput() - i.SetParam("age", strconv.FormatUint(math.MaxUint32, 10)) - ctx := &context.Context{Input: i} - ctrlr := Controller{Ctx: ctx} - val, _ := ctrlr.GetUint32("age") - if val != math.MaxUint32 { - t.Errorf("TestGetUint32 expect %v,get %T,%v", math.MaxUint32, val, val) - } -} - -func TestGetUint64(t *testing.T) { - i := context.NewInput() - i.SetParam("age", strconv.FormatUint(math.MaxUint64, 10)) - ctx := &context.Context{Input: i} - ctrlr := Controller{Ctx: ctx} - val, _ := ctrlr.GetUint64("age") - if val != math.MaxUint64 { - t.Errorf("TestGetUint64 expect %v,get %T,%v", uint64(math.MaxUint64), val, val) - } -} - -func TestAdditionalViewPaths(t *testing.T) { - dir1 := "_beeTmp" - dir2 := "_beeTmp2" - defer os.RemoveAll(dir1) - defer os.RemoveAll(dir2) - - dir1file := "file1.tpl" - dir2file := "file2.tpl" - - genFile := func(dir string, name string, content string) { - os.MkdirAll(filepath.Dir(filepath.Join(dir, name)), 0777) - if f, err := os.Create(filepath.Join(dir, name)); err != nil { - t.Fatal(err) - } else { - defer f.Close() - f.WriteString(content) - f.Close() - } - - } - genFile(dir1, dir1file, `
{{.Content}}
`) - genFile(dir2, dir2file, `{{.Content}}`) - - AddViewPath(dir1) - AddViewPath(dir2) - - ctrl := Controller{ - TplName: "file1.tpl", - ViewPath: dir1, - } - ctrl.Data = map[interface{}]interface{}{ - "Content": "value2", - } - if result, err := ctrl.RenderString(); err != nil { - t.Fatal(err) - } else { - if result != "
value2
" { - t.Fatalf("TestAdditionalViewPaths expect %s got %s", "
value2
", result) - } - } - - func() { - ctrl.TplName = "file2.tpl" - defer func() { - if r := recover(); r == nil { - t.Fatal("TestAdditionalViewPaths expected error") - } - }() - ctrl.RenderString() - }() - - ctrl.TplName = "file2.tpl" - ctrl.ViewPath = dir2 - ctrl.RenderString() -} diff --git a/core/admin/command.go b/core/admin/command.go new file mode 100644 index 0000000000..c3bc1d2e36 --- /dev/null +++ b/core/admin/command.go @@ -0,0 +1,87 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package admin + +import ( + "errors" +) + +// Command is an experimental interface +// We try to use this to decouple modules +// All other modules depends on this, and they register the command they support +// We may change the API in the future, so be careful about this. +type Command interface { + Execute(params ...interface{}) *Result +} + +var CommandNotFound = errors.New("Command not found") + +type Result struct { + // Status is the same as http.Status + Status int + Error error + Content interface{} +} + +func (r *Result) IsSuccess() bool { + return r.Status >= 200 && r.Status < 300 +} + +// CommandRegistry stores all commands +// name => command +type moduleCommands map[string]Command + +// Get returns command with the name +func (m moduleCommands) Get(name string) Command { + c, ok := m[name] + if ok { + return c + } + return &doNothingCommand{} +} + +// module name => moduleCommand +type commandRegistry map[string]moduleCommands + +// Get returns module's commands +func (c commandRegistry) Get(moduleName string) moduleCommands { + if mcs, ok := c[moduleName]; ok { + return mcs + } + res := make(moduleCommands) + c[moduleName] = res + return res +} + +var cmdRegistry = make(commandRegistry) + +// RegisterCommand is not thread-safe +// do not use it in concurrent case +func RegisterCommand(module string, commandName string, command Command) { + cmdRegistry.Get(module)[commandName] = command +} + +func GetCommand(module string, cmdName string) Command { + return cmdRegistry.Get(module).Get(cmdName) +} + +type doNothingCommand struct{} + +func (d *doNothingCommand) Execute(params ...interface{}) *Result { + return &Result{ + Status: 404, + Error: CommandNotFound, + } +} diff --git a/toolbox/healthcheck.go b/core/admin/healthcheck.go similarity index 80% rename from toolbox/healthcheck.go rename to core/admin/healthcheck.go index e3544b3ad4..077b1c5307 100644 --- a/toolbox/healthcheck.go +++ b/core/admin/healthcheck.go @@ -12,23 +12,21 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Package toolbox healthcheck +// Package admin healthcheck // // type DatabaseCheck struct { // } // -// func (dc *DatabaseCheck) Check() error { -// if dc.isConnected() { -// return nil -// } else { -// return errors.New("can't connect database") -// } -// } +// func (dc *DatabaseCheck) Check() error { +// if dc.isConnected() { +// return nil +// } else { +// return errors.New("can't connect database") +// } +// } // // AddHealthCheck("database",&DatabaseCheck{}) -// -// more docs: http://beego.me/docs/module/toolbox.md -package toolbox +package admin // AdminCheckList holds health checker map var AdminCheckList map[string]HealthChecker diff --git a/toolbox/profile.go b/core/admin/profile.go similarity index 79% rename from toolbox/profile.go rename to core/admin/profile.go index 06e40ede73..f85afaa2c7 100644 --- a/toolbox/profile.go +++ b/core/admin/profile.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package toolbox +package admin import ( "fmt" @@ -25,10 +25,14 @@ import ( "runtime/pprof" "strconv" "time" + + "github.com/beego/beego/v2/core/utils" ) -var startTime = time.Now() -var pid int +var ( + startTime = time.Now() + pid int +) func init() { pid = os.Getpid() @@ -103,27 +107,26 @@ func PrintGCSummary(w io.Writer) { } func printGC(memStats *runtime.MemStats, gcstats *debug.GCStats, w io.Writer) { - if gcstats.NumGC > 0 { lastPause := gcstats.Pause[0] - elapsed := time.Now().Sub(startTime) + elapsed := time.Since(startTime) overhead := float64(gcstats.PauseTotal) / float64(elapsed) * 100 allocatedRate := float64(memStats.TotalAlloc) / elapsed.Seconds() fmt.Fprintf(w, "NumGC:%d Pause:%s Pause(Avg):%s Overhead:%3.2f%% Alloc:%s Sys:%s Alloc(Rate):%s/s Histogram:%s %s %s \n", gcstats.NumGC, - toS(lastPause), - toS(avg(gcstats.Pause)), + utils.ToShortTimeFormat(lastPause), + utils.ToShortTimeFormat(avg(gcstats.Pause)), overhead, toH(memStats.Alloc), toH(memStats.Sys), toH(uint64(allocatedRate)), - toS(gcstats.PauseQuantiles[94]), - toS(gcstats.PauseQuantiles[98]), - toS(gcstats.PauseQuantiles[99])) + utils.ToShortTimeFormat(gcstats.PauseQuantiles[94]), + utils.ToShortTimeFormat(gcstats.PauseQuantiles[98]), + utils.ToShortTimeFormat(gcstats.PauseQuantiles[99])) } else { // while GC has disabled - elapsed := time.Now().Sub(startTime) + elapsed := time.Since(startTime) allocatedRate := float64(memStats.TotalAlloc) / elapsed.Seconds() fmt.Fprintf(w, "Alloc:%s Sys:%s Alloc(Rate):%s/s\n", @@ -141,7 +144,7 @@ func avg(items []time.Duration) time.Duration { return time.Duration(int64(sum) / int64(len(items))) } -// format bytes number friendly +// toH format bytes number friendly func toH(bytes uint64) string { switch { case bytes < 1024: @@ -154,31 +157,3 @@ func toH(bytes uint64) string { return fmt.Sprintf("%.2fG", float64(bytes)/1024/1024/1024) } } - -// short string format -func toS(d time.Duration) string { - - u := uint64(d) - if u < uint64(time.Second) { - switch { - case u == 0: - return "0" - case u < uint64(time.Microsecond): - return fmt.Sprintf("%.2fns", float64(u)) - case u < uint64(time.Millisecond): - return fmt.Sprintf("%.2fus", float64(u)/1000) - default: - return fmt.Sprintf("%.2fms", float64(u)/1000/1000) - } - } else { - switch { - case u < uint64(time.Minute): - return fmt.Sprintf("%.2fs", float64(u)/1000/1000/1000) - case u < uint64(time.Hour): - return fmt.Sprintf("%.2fm", float64(u)/1000/1000/1000/60) - default: - return fmt.Sprintf("%.2fh", float64(u)/1000/1000/1000/60/60) - } - } - -} diff --git a/toolbox/profile_test.go b/core/admin/profile_test.go similarity index 98% rename from toolbox/profile_test.go rename to core/admin/profile_test.go index 07a20c4eea..139c4b9908 100644 --- a/toolbox/profile_test.go +++ b/core/admin/profile_test.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package toolbox +package admin import ( "os" diff --git a/core/bean/context.go b/core/bean/context.go new file mode 100644 index 0000000000..7866643708 --- /dev/null +++ b/core/bean/context.go @@ -0,0 +1,19 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package bean + +// ApplicationContext define for future +// when we decide to support DI, IoC, this will be core API +type ApplicationContext interface{} diff --git a/core/bean/doc.go b/core/bean/doc.go new file mode 100644 index 0000000000..049fe5639d --- /dev/null +++ b/core/bean/doc.go @@ -0,0 +1,17 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package bean is a basic package +// it should not depend on other modules except common module, log module and config module +package bean diff --git a/core/bean/factory.go b/core/bean/factory.go new file mode 100644 index 0000000000..1097604c5e --- /dev/null +++ b/core/bean/factory.go @@ -0,0 +1,25 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package bean + +import ( + "context" +) + +// AutoWireBeanFactory wire the bean based on ApplicationContext and context.Context +type AutoWireBeanFactory interface { + // AutoWire will wire the bean. + AutoWire(ctx context.Context, appCtx ApplicationContext, bean interface{}) error +} diff --git a/core/bean/metadata.go b/core/bean/metadata.go new file mode 100644 index 0000000000..e2e34f55c4 --- /dev/null +++ b/core/bean/metadata.go @@ -0,0 +1,28 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package bean + +// BeanMetadata, in other words, bean's config. +// it could be read from config file +type BeanMetadata struct { + // Fields: field name => field metadata + Fields map[string]*FieldMetadata +} + +// FieldMetadata contains metadata +type FieldMetadata struct { + // default value in string format + DftValue string +} diff --git a/core/bean/mock.go b/core/bean/mock.go new file mode 100644 index 0000000000..86d90c74c4 --- /dev/null +++ b/core/bean/mock.go @@ -0,0 +1,142 @@ +package bean + +import ( + "fmt" + "reflect" + "strconv" + "strings" +) + +// Mock have a mock object ,it must be pointer of struct +// the element in mock object can be slices, structures, basic data types, pointers and interface +func Mock(v interface{}) (err error) { + pv := reflect.ValueOf(v) + // the input must be pointer of struct + if pv.Kind() != reflect.Ptr || pv.IsNil() { + err = fmt.Errorf("not a pointer of struct") + return + } + err = mock(pv) + return +} + +func mock(pv reflect.Value) (err error) { + pt := pv.Type() + for i := 0; i < pt.Elem().NumField(); i++ { + ptt := pt.Elem().Field(i) + pvv := pv.Elem().FieldByName(ptt.Name) + if !pvv.CanSet() { + continue + } + kt := ptt.Type.Kind() + tagValue := ptt.Tag.Get("mock") + switch kt { + case reflect.Map: + continue + case reflect.Interface: + if pvv.IsNil() { // when interface is nil,can not sure the type + continue + } + pvv.Set(reflect.New(pvv.Elem().Type().Elem())) + err = mock(pvv.Elem()) + case reflect.Ptr: + err = mockPtr(pvv, ptt.Type.Elem()) + case reflect.Struct: + err = mock(pvv.Addr()) + case reflect.Array, reflect.Slice: + err = mockSlice(tagValue, pvv) + case reflect.String: + pvv.SetString(tagValue) + case reflect.Bool: + err = mockBool(tagValue, pvv) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + value, e := strconv.ParseInt(tagValue, 10, 64) + if e != nil || pvv.OverflowInt(value) { + err = fmt.Errorf("the value:%s is invalid", tagValue) + } + pvv.SetInt(value) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + value, e := strconv.ParseUint(tagValue, 10, 64) + if e != nil || pvv.OverflowUint(value) { + err = fmt.Errorf("the value:%s is invalid", tagValue) + } + pvv.SetUint(value) + case reflect.Float32, reflect.Float64: + value, e := strconv.ParseFloat(tagValue, pvv.Type().Bits()) + if e != nil || pvv.OverflowFloat(value) { + err = fmt.Errorf("the value:%s is invalid", tagValue) + } + pvv.SetFloat(value) + default: + } + if err != nil { + return + } + } + return +} + +// mock slice value +func mockSlice(tagValue string, pvv reflect.Value) (err error) { + if tagValue == "" { + return + } + sliceMetas := strings.Split(tagValue, ":") + if len(sliceMetas) != 2 || sliceMetas[0] != "length" { + err = fmt.Errorf("the value:%s is invalid", tagValue) + return + } + length, e := strconv.Atoi(sliceMetas[1]) + if e != nil { + return e + } + + sliceType := reflect.SliceOf(pvv.Type().Elem()) // get slice type + itemType := sliceType.Elem() // get the type of item in slice + value := reflect.MakeSlice(sliceType, 0, length) + newSliceValue := make([]reflect.Value, 0, length) + for k := 0; k < length; k++ { + itemValue := reflect.New(itemType).Elem() + // if item in slice is struct or pointer,must set zero value + switch itemType.Kind() { + case reflect.Struct: + err = mock(itemValue.Addr()) + case reflect.Ptr: + if itemValue.IsNil() { + itemValue.Set(reflect.New(itemType.Elem())) + if e := mock(itemValue); e != nil { + return e + } + } + } + newSliceValue = append(newSliceValue, itemValue) + if err != nil { + return + } + } + value = reflect.Append(value, newSliceValue...) + pvv.Set(value) + return +} + +// mock bool value +func mockBool(tagValue string, pvv reflect.Value) (err error) { + switch tagValue { + case "true": + pvv.SetBool(true) + case "false": + pvv.SetBool(false) + default: + err = fmt.Errorf("the value:%s is invalid", tagValue) + } + return +} + +// mock pointer +func mockPtr(pvv reflect.Value, ptt reflect.Type) (err error) { + if pvv.IsNil() { + pvv.Set(reflect.New(ptt)) // must set nil value to zero value + } + err = mock(pvv) + return +} diff --git a/core/bean/mock_test.go b/core/bean/mock_test.go new file mode 100644 index 0000000000..0fe81f435f --- /dev/null +++ b/core/bean/mock_test.go @@ -0,0 +1,74 @@ +package bean + +import ( + "fmt" + "testing" +) + +func TestMock(t *testing.T) { + type MockSubSubObject struct { + A int `mock:"20"` + } + type MockSubObjectAnoy struct { + Anoy int `mock:"20"` + } + type MockSubObject struct { + A bool `mock:"true"` + B MockSubSubObject + } + type MockObject struct { + A string `mock:"aaaaa"` + B int8 `mock:"10"` + C []*MockSubObject `mock:"length:2"` + D bool `mock:"true"` + E *MockSubObject + F []int `mock:"length:3"` + G InterfaceA + H InterfaceA + MockSubObjectAnoy + } + m := &MockObject{G: &ImplA{}} + err := Mock(m) + if err != nil { + t.Fatalf("mock failed: %v", err) + } + if m.A != "aaaaa" || m.B != 10 || m.C[1].B.A != 20 || + !m.E.A || m.E.B.A != 20 || !m.D || len(m.F) != 3 { + t.Fail() + } + _, ok := m.G.(*ImplA) + if !ok { + t.Fail() + } + _, ok = m.G.(*ImplB) + if ok { + t.Fail() + } + _, ok = m.H.(*ImplA) + if ok { + t.Fail() + } + if m.Anoy != 20 { + t.Fail() + } +} + +type InterfaceA interface { + Item() +} + +type ImplA struct { + A string `mock:"aaa"` +} + +func (i *ImplA) Item() { + fmt.Println("implA") +} + +type ImplB struct { + B string `mock:"bbb"` +} + +func (i *ImplB) Item() { + fmt.Println("implB") +} diff --git a/core/bean/tag_auto_wire_bean_factory.go b/core/bean/tag_auto_wire_bean_factory.go new file mode 100644 index 0000000000..71426ecc15 --- /dev/null +++ b/core/bean/tag_auto_wire_bean_factory.go @@ -0,0 +1,225 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package bean + +import ( + "context" + "fmt" + "reflect" + "strconv" + + "github.com/beego/beego/v2/core/logs" +) + +const DefaultValueTagKey = "default" + +// TagAutoWireBeanFactory wire the bean based on Fields' tag +// if field's value is "zero value", we will execute injection +// see reflect.Value.IsZero() +// If field's kind is one of(reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Slice +// reflect.UnsafePointer, reflect.Array, reflect.Uintptr, reflect.Complex64, reflect.Complex128 +// reflect.Ptr, reflect.Struct), +// it will be ignored +type TagAutoWireBeanFactory struct { + // we allow user register their TypeAdapter + Adapters map[string]TypeAdapter + + // FieldTagParser is an extension point which means that you can custom how to read field's metadata from tag + FieldTagParser func(field reflect.StructField) *FieldMetadata +} + +// NewTagAutoWireBeanFactory create an instance of TagAutoWireBeanFactory +// by default, we register Time adapter, the time will be parse by using layout "2006-01-02 15:04:05" +// If you need more adapter, you can implement interface TypeAdapter +func NewTagAutoWireBeanFactory() *TagAutoWireBeanFactory { + return &TagAutoWireBeanFactory{ + Adapters: map[string]TypeAdapter{ + "Time": &TimeTypeAdapter{Layout: "2006-01-02 15:04:05"}, + }, + + FieldTagParser: func(field reflect.StructField) *FieldMetadata { + return &FieldMetadata{ + DftValue: field.Tag.Get(DefaultValueTagKey), + } + }, + } +} + +// AutoWire use value from appCtx to wire the bean, or use default value, or do nothing +func (t *TagAutoWireBeanFactory) AutoWire(ctx context.Context, appCtx ApplicationContext, bean interface{}) error { + if bean == nil { + return nil + } + + v := reflect.Indirect(reflect.ValueOf(bean)) + + bm := t.getConfig(v) + + // field name, field metadata + for fn, fm := range bm.Fields { + + fValue := v.FieldByName(fn) + if len(fm.DftValue) == 0 || !t.needInject(fValue) || !fValue.CanSet() { + continue + } + + // handle type adapter + typeName := fValue.Type().Name() + if adapter, ok := t.Adapters[typeName]; ok { + dftValue, err := adapter.DefaultValue(ctx, fm.DftValue) + if err == nil { + fValue.Set(reflect.ValueOf(dftValue)) + continue + } else { + return err + } + } + + switch fValue.Kind() { + case reflect.Bool: + if v, err := strconv.ParseBool(fm.DftValue); err != nil { + return fmt.Errorf("can not convert the field[%s]'s default value[%s] to bool value: %w", + fn, fm.DftValue, err) + } else { + fValue.SetBool(v) + continue + } + case reflect.Int: + if err := t.setIntXValue(fm.DftValue, 0, fn, fValue); err != nil { + return err + } + continue + case reflect.Int8: + if err := t.setIntXValue(fm.DftValue, 8, fn, fValue); err != nil { + return err + } + continue + case reflect.Int16: + if err := t.setIntXValue(fm.DftValue, 16, fn, fValue); err != nil { + return err + } + continue + + case reflect.Int32: + if err := t.setIntXValue(fm.DftValue, 32, fn, fValue); err != nil { + return err + } + continue + + case reflect.Int64: + if err := t.setIntXValue(fm.DftValue, 64, fn, fValue); err != nil { + return err + } + continue + + case reflect.Uint: + if err := t.setUIntXValue(fm.DftValue, 0, fn, fValue); err != nil { + return err + } + + case reflect.Uint8: + if err := t.setUIntXValue(fm.DftValue, 8, fn, fValue); err != nil { + return err + } + continue + + case reflect.Uint16: + if err := t.setUIntXValue(fm.DftValue, 16, fn, fValue); err != nil { + return err + } + continue + case reflect.Uint32: + if err := t.setUIntXValue(fm.DftValue, 32, fn, fValue); err != nil { + return err + } + continue + + case reflect.Uint64: + if err := t.setUIntXValue(fm.DftValue, 64, fn, fValue); err != nil { + return err + } + continue + + case reflect.Float32: + if err := t.setFloatXValue(fm.DftValue, 32, fn, fValue); err != nil { + return err + } + continue + case reflect.Float64: + if err := t.setFloatXValue(fm.DftValue, 64, fn, fValue); err != nil { + return err + } + continue + + case reflect.String: + fValue.SetString(fm.DftValue) + continue + + // case reflect.Ptr: + // case reflect.Struct: + default: + logs.Warn("this field[%s] has default setting, but we don't support this type: %s", + fn, fValue.Kind().String()) + } + } + return nil +} + +func (t *TagAutoWireBeanFactory) setFloatXValue(dftValue string, bitSize int, fn string, fv reflect.Value) error { + if v, err := strconv.ParseFloat(dftValue, bitSize); err != nil { + return fmt.Errorf("can not convert the field[%s]'s default value[%s] to float%d value: %w", + fn, dftValue, bitSize, err) + } else { + fv.SetFloat(v) + return nil + } +} + +func (t *TagAutoWireBeanFactory) setUIntXValue(dftValue string, bitSize int, fn string, fv reflect.Value) error { + if v, err := strconv.ParseUint(dftValue, 10, bitSize); err != nil { + return fmt.Errorf("can not convert the field[%s]'s default value[%s] to uint%d value: %w", + fn, dftValue, bitSize, err) + } else { + fv.SetUint(v) + return nil + } +} + +func (t *TagAutoWireBeanFactory) setIntXValue(dftValue string, bitSize int, fn string, fv reflect.Value) error { + if v, err := strconv.ParseInt(dftValue, 10, bitSize); err != nil { + return fmt.Errorf("can not convert the field[%s]'s default value[%s] to int%d value: %w", + fn, dftValue, bitSize, err) + } else { + fv.SetInt(v) + return nil + } +} + +func (t *TagAutoWireBeanFactory) needInject(fValue reflect.Value) bool { + return fValue.IsZero() +} + +// getConfig never return nil +func (t *TagAutoWireBeanFactory) getConfig(beanValue reflect.Value) *BeanMetadata { + fms := make(map[string]*FieldMetadata, beanValue.NumField()) + for i := 0; i < beanValue.NumField(); i++ { + // f => StructField + f := beanValue.Type().Field(i) + fms[f.Name] = t.FieldTagParser(f) + } + return &BeanMetadata{ + Fields: fms, + } +} diff --git a/core/bean/tag_auto_wire_bean_factory_test.go b/core/bean/tag_auto_wire_bean_factory_test.go new file mode 100644 index 0000000000..b5744af7fb --- /dev/null +++ b/core/bean/tag_auto_wire_bean_factory_test.go @@ -0,0 +1,76 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package bean + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestTagAutoWireBeanFactory_AutoWire(t *testing.T) { + factory := NewTagAutoWireBeanFactory() + bm := &ComplicateStruct{} + err := factory.AutoWire(context.Background(), nil, bm) + assert.Nil(t, err) + assert.Equal(t, 12, bm.IntValue) + assert.Equal(t, "hello, strValue", bm.StrValue) + + assert.Equal(t, int8(8), bm.Int8Value) + assert.Equal(t, int16(16), bm.Int16Value) + assert.Equal(t, int32(32), bm.Int32Value) + assert.Equal(t, int64(64), bm.Int64Value) + + assert.Equal(t, uint(13), bm.UintValue) + assert.Equal(t, uint8(88), bm.Uint8Value) + assert.Equal(t, uint16(1616), bm.Uint16Value) + assert.Equal(t, uint32(3232), bm.Uint32Value) + assert.Equal(t, uint64(6464), bm.Uint64Value) + + assert.Equal(t, float32(32.32), bm.Float32Value) + assert.Equal(t, float64(64.64), bm.Float64Value) + + assert.True(t, bm.BoolValue) + assert.Equal(t, 0, bm.ignoreInt) + + assert.NotNil(t, bm.TimeValue) +} + +type ComplicateStruct struct { + BoolValue bool `default:"true"` + Int8Value int8 `default:"8"` + Uint8Value uint8 `default:"88"` + + Int16Value int16 `default:"16"` + Uint16Value uint16 `default:"1616"` + Int32Value int32 `default:"32"` + Uint32Value uint32 `default:"3232"` + + IntValue int `default:"12"` + UintValue uint `default:"13"` + Int64Value int64 `default:"64"` + Uint64Value uint64 `default:"6464"` + + StrValue string `default:"hello, strValue"` + + Float32Value float32 `default:"32.32"` + Float64Value float64 `default:"64.64"` + + ignoreInt int `default:"11"` + + TimeValue time.Time `default:"2018-02-03 12:13:14.000"` +} diff --git a/core/bean/time_type_adapter.go b/core/bean/time_type_adapter.go new file mode 100644 index 0000000000..b0e99896ed --- /dev/null +++ b/core/bean/time_type_adapter.go @@ -0,0 +1,35 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package bean + +import ( + "context" + "time" +) + +// TimeTypeAdapter process the time.Time +type TimeTypeAdapter struct { + Layout string +} + +// DefaultValue parse the DftValue to time.Time +// and if the DftValue == now +// time.Now() is returned +func (t *TimeTypeAdapter) DefaultValue(ctx context.Context, dftValue string) (interface{}, error) { + if dftValue == "now" { + return time.Now(), nil + } + return time.Parse(t.Layout, dftValue) +} diff --git a/core/bean/time_type_adapter_test.go b/core/bean/time_type_adapter_test.go new file mode 100644 index 0000000000..140ef5a633 --- /dev/null +++ b/core/bean/time_type_adapter_test.go @@ -0,0 +1,29 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package bean + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestTimeTypeAdapter_DefaultValue(t *testing.T) { + typeAdapter := &TimeTypeAdapter{Layout: "2006-01-02 15:04:05"} + tm, err := typeAdapter.DefaultValue(context.Background(), "2018-02-03 12:34:11") + assert.Nil(t, err) + assert.NotNil(t, tm) +} diff --git a/core/bean/type_adapter.go b/core/bean/type_adapter.go new file mode 100644 index 0000000000..5869032d3e --- /dev/null +++ b/core/bean/type_adapter.go @@ -0,0 +1,26 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package bean + +import ( + "context" +) + +// TypeAdapter is an abstraction that define some behavior of target type +// usually, we don't use this to support basic type since golang has many restriction for basic types +// This is an important extension point +type TypeAdapter interface { + DefaultValue(ctx context.Context, dftValue string) (interface{}, error) +} diff --git a/core/berror/codes.go b/core/berror/codes.go new file mode 100644 index 0000000000..bee8bb1962 --- /dev/null +++ b/core/berror/codes.go @@ -0,0 +1,87 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package berror + +import ( + "fmt" + "sync" +) + +// A Code is an unsigned 32-bit error code as defined in the beego spec. +type Code interface { + Code() uint32 + Module() string + Desc() string + Name() string +} + +var defaultCodeRegistry = &codeRegistry{ + codes: make(map[uint32]*codeDefinition, 127), +} + +// DefineCode defining a new Code +// Before defining a new code, please read Beego specification. +// desc could be markdown doc +func DefineCode(code uint32, module string, name string, desc string) Code { + res := &codeDefinition{ + code: code, + name: name, + module: module, + desc: desc, + } + defaultCodeRegistry.lock.Lock() + defer defaultCodeRegistry.lock.Unlock() + + if _, ok := defaultCodeRegistry.codes[code]; ok { + panic(fmt.Sprintf("duplicate code, code %d has been registered", code)) + } + defaultCodeRegistry.codes[code] = res + return res +} + +type codeRegistry struct { + lock sync.RWMutex + codes map[uint32]*codeDefinition +} + +func (cr *codeRegistry) Get(code uint32) (Code, bool) { + cr.lock.RLock() + defer cr.lock.RUnlock() + c, ok := cr.codes[code] + return c, ok +} + +type codeDefinition struct { + code uint32 + module string + desc string + name string +} + +func (c *codeDefinition) Name() string { + return c.name +} + +func (c *codeDefinition) Code() uint32 { + return c.code +} + +func (c *codeDefinition) Module() string { + return c.module +} + +func (c *codeDefinition) Desc() string { + return c.desc +} diff --git a/core/berror/error.go b/core/berror/error.go new file mode 100644 index 0000000000..a501de44e3 --- /dev/null +++ b/core/berror/error.go @@ -0,0 +1,67 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package berror + +import ( + "fmt" + "strconv" + "strings" +) + +// code, msg +const errFmt = "ERROR-%d, %s" + +// Error returns an error representing c and msg. If c is OK, returns nil. +func Error(c Code, msg string) error { + return fmt.Errorf(errFmt, c.Code(), msg) +} + +// Errorf returns error +func Errorf(c Code, format string, a ...interface{}) error { + return Error(c, fmt.Sprintf(format, a...)) +} + +func Wrap(err error, c Code, msg string) error { + if err == nil { + return nil + } + return fmt.Errorf(errFmt+": %w", c.Code(), msg, err) +} + +func Wrapf(err error, c Code, format string, a ...interface{}) error { + return Wrap(err, c, fmt.Sprintf(format, a...)) +} + +// FromError is very simple. It just parse error msg and check whether code has been register +// if code not being register, return unknown +// if err.Error() is not valid beego error code, return unknown +func FromError(err error) (Code, bool) { + msg := err.Error() + codeSeg := strings.SplitN(msg, ",", 2) + if strings.HasPrefix(codeSeg[0], "ERROR-") { + codeStr := strings.SplitN(codeSeg[0], "-", 2) + if len(codeStr) < 2 { + return Unknown, false + } + codeInt, e := strconv.ParseUint(codeStr[1], 10, 32) + if e != nil { + return Unknown, false + } + if code, ok := defaultCodeRegistry.Get(uint32(codeInt)); ok { + return code, true + } + } + return Unknown, false +} diff --git a/core/berror/error_test.go b/core/berror/error_test.go new file mode 100644 index 0000000000..7a14d933cb --- /dev/null +++ b/core/berror/error_test.go @@ -0,0 +1,77 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package berror + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/assert" +) + +var testCode1 = DefineCode(1, "unit_test", "TestError", "Hello, test code1") + +var testErr = errors.New("hello, this is error") + +func TestErrorf(t *testing.T) { + msg := Errorf(testCode1, "errorf %s", "aaaa") + assert.NotNil(t, msg) + assert.Equal(t, "ERROR-1, errorf aaaa", msg.Error()) +} + +func TestWrapf(t *testing.T) { + err := Wrapf(testErr, testCode1, "Wrapf %s", "aaaa") + assert.NotNil(t, err) + assert.True(t, errors.Is(err, testErr)) +} + +func TestFromError(t *testing.T) { + err := errors.New("ERROR-1, errorf aaaa") + code, ok := FromError(err) + assert.True(t, ok) + assert.Equal(t, testCode1, code) + assert.Equal(t, "unit_test", code.Module()) + assert.Equal(t, "Hello, test code1", code.Desc()) + + err = errors.New("not beego error") + code, ok = FromError(err) + assert.False(t, ok) + assert.Equal(t, Unknown, code) + + err = errors.New("ERROR-2, not register") + code, ok = FromError(err) + assert.False(t, ok) + assert.Equal(t, Unknown, code) + + err = errors.New("ERROR-aaa, invalid code") + code, ok = FromError(err) + assert.False(t, ok) + assert.Equal(t, Unknown, code) + + err = errors.New("aaaaaaaaaaaaaa") + code, ok = FromError(err) + assert.False(t, ok) + assert.Equal(t, Unknown, code) + + err = errors.New("ERROR-2-3, invalid error") + code, ok = FromError(err) + assert.False(t, ok) + assert.Equal(t, Unknown, code) + + err = errors.New("ERROR, invalid error") + code, ok = FromError(err) + assert.False(t, ok) + assert.Equal(t, Unknown, code) +} diff --git a/core/berror/pre_define_code.go b/core/berror/pre_define_code.go new file mode 100644 index 0000000000..ff8eb46b63 --- /dev/null +++ b/core/berror/pre_define_code.go @@ -0,0 +1,52 @@ +// Copyright 2021 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package berror + +import ( + "fmt" +) + +// pre define code + +// Unknown indicates got some error which is not defined +var Unknown = DefineCode(5000001, "error", "Unknown", fmt.Sprintf(` +Unknown error code. Usually you will see this code in three cases: +1. You forget to define Code or function DefineCode not being executed; +2. This is not Beego's error but you call FromError(); +3. Beego got unexpected error and don't know how to handle it, and then return Unknown error + +A common practice to DefineCode looks like: +%s + +In this way, you may forget to import this package, and got Unknown error. + +Sometimes, you believe you got Beego error, but actually you don't, and then you call FromError(err) + +`, goCodeBlock(` +import your_package + +func init() { + DefineCode(5100100, "your_module", "detail") + // ... +} +`))) + +func goCodeBlock(code string) string { + return codeBlock("go", code) +} + +func codeBlock(lan string, code string) string { + return fmt.Sprintf("```%s\n%s\n```", lan, code) +} diff --git a/core/config/base_config_test.go b/core/config/base_config_test.go new file mode 100644 index 0000000000..340d05db01 --- /dev/null +++ b/core/config/base_config_test.go @@ -0,0 +1,71 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package config + +import ( + "context" + "errors" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestBaseConfiger_DefaultBool(t *testing.T) { + bc := newBaseConfier("true") + assert.True(t, bc.DefaultBool("key1", false)) + assert.True(t, bc.DefaultBool("key2", true)) +} + +func TestBaseConfiger_DefaultFloat(t *testing.T) { + bc := newBaseConfier("12.3") + assert.Equal(t, 12.3, bc.DefaultFloat("key1", 0.1)) + assert.Equal(t, 0.1, bc.DefaultFloat("key2", 0.1)) +} + +func TestBaseConfiger_DefaultInt(t *testing.T) { + bc := newBaseConfier("10") + assert.Equal(t, 10, bc.DefaultInt("key1", 8)) + assert.Equal(t, 8, bc.DefaultInt("key2", 8)) +} + +func TestBaseConfiger_DefaultInt64(t *testing.T) { + bc := newBaseConfier("64") + assert.Equal(t, int64(64), bc.DefaultInt64("key1", int64(8))) + assert.Equal(t, int64(8), bc.DefaultInt64("key2", int64(8))) +} + +func TestBaseConfiger_DefaultString(t *testing.T) { + bc := newBaseConfier("Hello") + assert.Equal(t, "Hello", bc.DefaultString("key1", "world")) + assert.Equal(t, "world", bc.DefaultString("key2", "world")) +} + +func TestBaseConfiger_DefaultStrings(t *testing.T) { + bc := newBaseConfier("Hello;world") + assert.Equal(t, []string{"Hello", "world"}, bc.DefaultStrings("key1", []string{"world"})) + assert.Equal(t, []string{"world"}, bc.DefaultStrings("key2", []string{"world"})) +} + +func newBaseConfier(str1 string) *BaseConfiger { + return &BaseConfiger{ + reader: func(ctx context.Context, key string) (string, error) { + if key == "key1" { + return str1, nil + } else { + return "", errors.New("mock error") + } + }, + } +} diff --git a/config/config.go b/core/config/config.go similarity index 56% rename from config/config.go rename to core/config/config.go index bfd79e85da..145333f93a 100644 --- a/config/config.go +++ b/core/config/config.go @@ -14,59 +14,188 @@ // Package config is used to parse config. // Usage: -// import "github.com/astaxie/beego/config" -//Examples. // -// cnf, err := config.NewConfig("ini", "config.conf") +// import "github.com/beego/beego/v2/core/config" // -// cnf APIS: +// Examples. // -// cnf.Set(key, val string) error -// cnf.String(key string) string -// cnf.Strings(key string) []string -// cnf.Int(key string) (int, error) -// cnf.Int64(key string) (int64, error) -// cnf.Bool(key string) (bool, error) -// cnf.Float(key string) (float64, error) -// cnf.DefaultString(key string, defaultVal string) string -// cnf.DefaultStrings(key string, defaultVal []string) []string -// cnf.DefaultInt(key string, defaultVal int) int -// cnf.DefaultInt64(key string, defaultVal int64) int64 -// cnf.DefaultBool(key string, defaultVal bool) bool -// cnf.DefaultFloat(key string, defaultVal float64) float64 -// cnf.DIY(key string) (interface{}, error) -// cnf.GetSection(section string) (map[string]string, error) -// cnf.SaveConfigFile(filename string) error -//More docs http://beego.me/docs/module/config.md +// cnf, err := config.NewConfig("ini", "config.conf") +// +// cnf APIS: +// +// cnf.Set(key, val string) error +// cnf.String(key string) string +// cnf.Strings(key string) []string +// cnf.Int(key string) (int, error) +// cnf.Int64(key string) (int64, error) +// cnf.Bool(key string) (bool, error) +// cnf.Float(key string) (float64, error) +// cnf.DefaultString(key string, defaultVal string) string +// cnf.DefaultStrings(key string, defaultVal []string) []string +// cnf.DefaultInt(key string, defaultVal int) int +// cnf.DefaultInt64(key string, defaultVal int64) int64 +// cnf.DefaultBool(key string, defaultVal bool) bool +// cnf.DefaultFloat(key string, defaultVal float64) float64 +// cnf.DIY(key string) (interface{}, error) +// cnf.GetSection(section string) (map[string]string, error) +// cnf.SaveConfigFile(filename string) error package config import ( + "context" + "errors" "fmt" "os" "reflect" + "strconv" + "strings" "time" ) // Configer defines how to get and set value from configuration raw data. type Configer interface { - Set(key, val string) error //support section::key type in given key when using ini type. - String(key string) string //support section::key type in key string when using ini and json type; Int,Int64,Bool,Float,DIY are same. - Strings(key string) []string //get string slice + // Set support section::key type in given key when using ini type. + Set(key, val string) error + + // String support section::key type in key string when using ini and json type; Int,Int64,Bool,Float,DIY are same. + String(key string) (string, error) + // Strings get string slice + Strings(key string) ([]string, error) Int(key string) (int, error) Int64(key string) (int64, error) Bool(key string) (bool, error) Float(key string) (float64, error) - DefaultString(key string, defaultVal string) string // support section::key type in key string when using ini and json type; Int,Int64,Bool,Float,DIY are same. - DefaultStrings(key string, defaultVal []string) []string //get string slice + // DefaultString support section::key type in key string when using ini and json type; Int,Int64,Bool,Float,DIY are same. + DefaultString(key string, defaultVal string) string + // DefaultStrings get string slice + DefaultStrings(key string, defaultVal []string) []string DefaultInt(key string, defaultVal int) int DefaultInt64(key string, defaultVal int64) int64 DefaultBool(key string, defaultVal bool) bool DefaultFloat(key string, defaultVal float64) float64 + + // DIY return the original value DIY(key string) (interface{}, error) + GetSection(section string) (map[string]string, error) + + Unmarshaler(prefix string, obj interface{}, opt ...DecodeOption) error + Sub(key string) (Configer, error) + OnChange(key string, fn func(value string)) SaveConfigFile(filename string) error } +type BaseConfiger struct { + // The reader should support key like "a.b.c" + reader func(ctx context.Context, key string) (string, error) +} + +func NewBaseConfiger(reader func(ctx context.Context, key string) (string, error)) BaseConfiger { + return BaseConfiger{ + reader: reader, + } +} + +func (c *BaseConfiger) Int(key string) (int, error) { + res, err := c.reader(context.TODO(), key) + if err != nil { + return 0, err + } + return strconv.Atoi(res) +} + +func (c *BaseConfiger) Int64(key string) (int64, error) { + res, err := c.reader(context.TODO(), key) + if err != nil { + return 0, err + } + return strconv.ParseInt(res, 10, 64) +} + +func (c *BaseConfiger) Bool(key string) (bool, error) { + res, err := c.reader(context.TODO(), key) + if err != nil { + return false, err + } + return ParseBool(res) +} + +func (c *BaseConfiger) Float(key string) (float64, error) { + res, err := c.reader(context.TODO(), key) + if err != nil { + return 0, err + } + return strconv.ParseFloat(res, 64) +} + +// DefaultString returns the string value for a given key. +// if err != nil or value is empty return defaultval +func (c *BaseConfiger) DefaultString(key string, defaultVal string) string { + if res, err := c.String(key); res != "" && err == nil { + return res + } + return defaultVal +} + +// DefaultStrings returns the []string value for a given key. +// if err != nil return defaultval +func (c *BaseConfiger) DefaultStrings(key string, defaultVal []string) []string { + if res, err := c.Strings(key); len(res) > 0 && err == nil { + return res + } + return defaultVal +} + +func (c *BaseConfiger) DefaultInt(key string, defaultVal int) int { + if res, err := c.Int(key); err == nil { + return res + } + return defaultVal +} + +func (c *BaseConfiger) DefaultInt64(key string, defaultVal int64) int64 { + if res, err := c.Int64(key); err == nil { + return res + } + return defaultVal +} + +func (c *BaseConfiger) DefaultBool(key string, defaultVal bool) bool { + if res, err := c.Bool(key); err == nil { + return res + } + return defaultVal +} + +func (c *BaseConfiger) DefaultFloat(key string, defaultVal float64) float64 { + if res, err := c.Float(key); err == nil { + return res + } + return defaultVal +} + +func (c *BaseConfiger) String(key string) (string, error) { + return c.reader(context.TODO(), key) +} + +// Strings returns the []string value for a given key. +// Return nil if config value does not exist or is empty. +func (c *BaseConfiger) Strings(key string) ([]string, error) { + res, err := c.String(key) + if err != nil || res == "" { + return nil, err + } + return strings.Split(res, ";"), nil +} + +func (*BaseConfiger) Sub(string) (Configer, error) { + return nil, errors.New("unsupported operation") +} + +func (*BaseConfiger) OnChange(_ string, _ func(value string)) { + // do nothing +} + // Config is the adapter interface for parsing config file to get raw data to Configer. type Config interface { Parse(key string) (Configer, error) @@ -121,6 +250,12 @@ func ExpandValueEnvForMap(m map[string]interface{}) map[string]interface{} { value[k2] = ExpandValueEnv(v2) } m[k] = value + case map[interface{}]interface{}: + tmp := make(map[string]interface{}, len(value)) + for k2, v2 := range value { + tmp[k2.(string)] = v2 + } + m[k] = ExpandValueEnvForMap(tmp) } } return m @@ -133,6 +268,7 @@ func ExpandValueEnvForMap(m map[string]interface{}) map[string]interface{} { // // It accept value formats "${env}" , "${env||}}" , "${env||defaultValue}" , "defaultvalue". // Examples: +// // v1 := config.ExpandValueEnv("${GOPATH}") // return the GOPATH environment variable. // v2 := config.ExpandValueEnv("${GOAsta||/usr/local/go}") // return the default value "/usr/local/go/". // v3 := config.ExpandValueEnv("Astaxie") // return the value "Astaxie". @@ -240,3 +376,7 @@ func ToString(x interface{}) string { // Fallback to fmt package for anything else like numeric types return fmt.Sprint(x) } + +type DecodeOption func(options decodeOptions) + +type decodeOptions struct{} diff --git a/config/config_test.go b/core/config/config_test.go similarity index 99% rename from config/config_test.go rename to core/config/config_test.go index 15d6ffa615..86d3a2c590 100644 --- a/config/config_test.go +++ b/core/config/config_test.go @@ -20,7 +20,6 @@ import ( ) func TestExpandValueEnv(t *testing.T) { - testCases := []struct { item string want string @@ -51,5 +50,4 @@ func TestExpandValueEnv(t *testing.T) { t.Errorf("expand value error, item %q want %q, got %q", c.item, c.want, got) } } - } diff --git a/core/config/env/env.go b/core/config/env/env.go new file mode 100644 index 0000000000..8d3b2ae295 --- /dev/null +++ b/core/config/env/env.go @@ -0,0 +1,176 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// Copyright 2017 Faissal Elamraoui. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package env is used to parse environment. +package env + +import ( + "fmt" + "go/build" + "os" + "path/filepath" + "strings" + "sync" +) + +var env sync.Map + +func init() { + for _, e := range os.Environ() { + splits := strings.Split(e, "=") + env.Store(splits[0], os.Getenv(splits[0])) + } + env.Store("GOBIN", GetGOBIN()) // set non-empty GOBIN when initialization + env.Store("GOPATH", GetGOPATH()) // set non-empty GOPATH when initialization +} + +// Get returns a value for a given key. +// If the key does not exist, the default value will be returned. +func Get(key string, defVal string) string { + if val, ok := env.Load(key); ok { + return val.(string) + } + return defVal +} + +// MustGet returns a value by key. +// If the key does not exist, it will return an error. +func MustGet(key string) (string, error) { + if val, ok := env.Load(key); ok { + return val.(string), nil + } + return "", fmt.Errorf("no env variable with %s", key) +} + +// Set sets a value in the ENV copy. +// This does not affect the child process environment. +func Set(key string, value string) { + env.Store(key, value) +} + +// MustSet sets a value in the ENV copy and the child process environment. +// It returns an error in case the set operation failed. +func MustSet(key string, value string) error { + err := os.Setenv(key, value) + if err != nil { + return err + } + env.Store(key, value) + return nil +} + +// GetAll returns all keys/values in the current child process environment. +func GetAll() map[string]string { + envs := make(map[string]string, 32) + env.Range(func(key, value interface{}) bool { + switch key := key.(type) { + case string: + switch val := value.(type) { + case string: + envs[key] = val + } + } + return true + }) + return envs +} + +// envFile returns the name of the Go environment configuration file. +// Copy from https://github.com/golang/go/blob/c4f2a9788a7be04daf931ac54382fbe2cb754938/src/cmd/go/internal/cfg/cfg.go#L150-L166 +func envFile() (string, error) { + if file := os.Getenv("GOENV"); file != "" { + if file == "off" { + return "", fmt.Errorf("GOENV=off") + } + return file, nil + } + dir, err := os.UserConfigDir() + if err != nil { + return "", err + } + if dir == "" { + return "", fmt.Errorf("missing user-config dir") + } + return filepath.Join(dir, "go", "env"), nil +} + +// GetRuntimeEnv returns the value of runtime environment variable, +// that is set by running following command: `go env -w key=value`. +func GetRuntimeEnv(key string) (string, error) { + file, err := envFile() + if err != nil { + return "", err + } + if file == "" { + return "", fmt.Errorf("missing runtime env file") + } + + var runtimeEnv string + data, err := os.ReadFile(file) + if err != nil { + return "", err + } + envStrings := strings.Split(string(data), "\n") + for _, envItem := range envStrings { + envItem = strings.TrimSuffix(envItem, "\r") + envKeyValue := strings.Split(envItem, "=") + if len(envKeyValue) == 2 && strings.TrimSpace(envKeyValue[0]) == key { + runtimeEnv = strings.TrimSpace(envKeyValue[1]) + } + } + return runtimeEnv, nil +} + +// GetGOBIN returns GOBIN environment variable as a string. +// It will NOT be an empty string. +func GetGOBIN() string { + // The one set by user explicitly by `export GOBIN=/path` or `env GOBIN=/path command` + gobin := strings.TrimSpace(Get("GOBIN", "")) + if gobin == "" { + var err error + // The one set by user by running `go env -w GOBIN=/path` + gobin, err = GetRuntimeEnv("GOBIN") + if err != nil { + // The default one that Golang uses + return filepath.Join(build.Default.GOPATH, "bin") + } + if gobin == "" { + return filepath.Join(build.Default.GOPATH, "bin") + } + return gobin + } + return gobin +} + +// GetGOPATH returns GOPATH environment variable as a string. +// It will NOT be an empty string. +func GetGOPATH() string { + // The one set by user explicitly by `export GOPATH=/path` or `env GOPATH=/path command` + gopath := strings.TrimSpace(Get("GOPATH", "")) + if gopath == "" { + var err error + // The one set by user by running `go env -w GOPATH=/path` + gopath, err = GetRuntimeEnv("GOPATH") + if err != nil { + // The default one that Golang uses + return build.Default.GOPATH + } + if gopath == "" { + return build.Default.GOPATH + } + return gopath + } + return gopath +} diff --git a/config/env/env_test.go b/core/config/env/env_test.go similarity index 57% rename from config/env/env_test.go rename to core/config/env/env_test.go index 3f1d4dbab2..eabff737ad 100644 --- a/config/env/env_test.go +++ b/core/config/env/env_test.go @@ -5,7 +5,7 @@ // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // -// http://www.apache.org/licenses/LICENSE-2.0 +// http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, @@ -15,7 +15,9 @@ package env import ( + "go/build" "os" + "path/filepath" "testing" ) @@ -73,3 +75,41 @@ func TestEnvGetAll(t *testing.T) { t.Error("expected environment not empty.") } } + +func TestEnvFile(t *testing.T) { + envFile, err := envFile() + if err != nil { + t.Errorf("expected to get env file without error, but got %s.", err) + } + if envFile == "" { + t.Error("expected to get valid env file, but got empty string.") + } +} + +func TestGetGOBIN(t *testing.T) { + customGOBIN := filepath.Join("path", "to", "gobin") + Set("GOBIN", customGOBIN) + if gobin := GetGOBIN(); gobin != Get("GOBIN", "") { + t.Errorf("expected GOBIN environment variable equals to %s, but got %s.", customGOBIN, gobin) + } + + Set("GOBIN", "") + defaultGOBIN := filepath.Join(build.Default.GOPATH, "bin") + if gobin := GetGOBIN(); gobin != defaultGOBIN { + t.Errorf("expected GOBIN environment variable equals to %s, but got %s.", defaultGOBIN, gobin) + } +} + +func TestGetGOPATH(t *testing.T) { + customGOPATH := filepath.Join("path", "to", "gopath") + Set("GOPATH", customGOPATH) + if goPath := GetGOPATH(); goPath != Get("GOPATH", "") { + t.Errorf("expected GOPATH environment variable equals to %s, but got %s.", customGOPATH, goPath) + } + + Set("GOPATH", "") + defaultGOPATH := build.Default.GOPATH + if goPath := GetGOPATH(); goPath != defaultGOPATH { + t.Errorf("expected GOPATH environment variable equals to %s, but got %s.", defaultGOPATH, goPath) + } +} diff --git a/core/config/error.go b/core/config/error.go new file mode 100644 index 0000000000..728163aa58 --- /dev/null +++ b/core/config/error.go @@ -0,0 +1,25 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package config + +import ( + "errors" +) + +// now not all implementation return those error codes +var ( + KeyNotFoundError = errors.New("the key is not found") + InvalidValueTypeError = errors.New("the value is not expected type") +) diff --git a/core/config/etcd/config.go b/core/config/etcd/config.go new file mode 100644 index 0000000000..0516fea49c --- /dev/null +++ b/core/config/etcd/config.go @@ -0,0 +1,192 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package etcd + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "time" + + grpc_prometheus "github.com/grpc-ecosystem/go-grpc-prometheus" + "github.com/mitchellh/mapstructure" + clientv3 "go.etcd.io/etcd/client/v3" + "google.golang.org/grpc" + + "github.com/beego/beego/v2/core/config" + "github.com/beego/beego/v2/core/logs" +) + +type EtcdConfiger struct { + prefix string + client *clientv3.Client + config.BaseConfiger +} + +func newEtcdConfiger(client *clientv3.Client, prefix string) *EtcdConfiger { + res := &EtcdConfiger{ + client: client, + prefix: prefix, + } + + res.BaseConfiger = config.NewBaseConfiger(res.reader) + return res +} + +// reader is a general implementation that read config from etcd. +func (e *EtcdConfiger) reader(ctx context.Context, key string) (string, error) { + resp, err := get(e.client, e.prefix+key) + if err != nil { + return "", err + } + + if resp.Count > 0 { + return string(resp.Kvs[0].Value), nil + } + + return "", nil +} + +// Set do nothing and return an error +// I think write data to remote config center is not a good practice +func (e *EtcdConfiger) Set(key, val string) error { + return errors.New("Unsupported operation") +} + +// DIY return the original response from etcd +// be careful when you decide to use this +func (e *EtcdConfiger) DIY(key string) (interface{}, error) { + return get(e.client, key) +} + +// GetSection in this implementation, we use section as prefix +func (e *EtcdConfiger) GetSection(section string) (map[string]string, error) { + var ( + resp *clientv3.GetResponse + err error + ) + + resp, err = e.client.Get(context.TODO(), e.prefix+section, clientv3.WithPrefix()) + + if err != nil { + return nil, fmt.Errorf("GetSection failed: %w", err) + } + res := make(map[string]string, len(resp.Kvs)) + for _, kv := range resp.Kvs { + res[string(kv.Key)] = string(kv.Value) + } + return res, nil +} + +func (e *EtcdConfiger) SaveConfigFile(filename string) error { + return errors.New("Unsupported operation") +} + +// Unmarshaler is not very powerful because we lost the type information when we get configuration from etcd +// for example, when we got "5", we are not sure whether it's int 5, or it's string "5" +// TODO(support more complicated decoder) +func (e *EtcdConfiger) Unmarshaler(prefix string, obj interface{}, opt ...config.DecodeOption) error { + res, err := e.GetSection(prefix) + if err != nil { + return fmt.Errorf("could not read config with prefix: %s: %w", prefix, err) + } + + prefixLen := len(e.prefix + prefix) + m := make(map[string]string, len(res)) + for k, v := range res { + m[k[prefixLen:]] = v + } + return mapstructure.Decode(m, obj) +} + +// Sub return an sub configer. +func (e *EtcdConfiger) Sub(key string) (config.Configer, error) { + return newEtcdConfiger(e.client, e.prefix+key), nil +} + +// TODO remove this before release v2.0.0 +func (e *EtcdConfiger) OnChange(key string, fn func(value string)) { + buildOptsFunc := func() []clientv3.OpOption { + return []clientv3.OpOption{} + } + + rch := e.client.Watch(context.Background(), e.prefix+key, buildOptsFunc()...) + go func() { + for { + for resp := range rch { + if err := resp.Err(); err != nil { + logs.Error("listen to key but got error callback", err) + break + } + + for _, e := range resp.Events { + if e.Kv == nil { + continue + } + fn(string(e.Kv.Value)) + } + } + time.Sleep(time.Second) + rch = e.client.Watch(context.Background(), e.prefix+key, buildOptsFunc()...) + } + }() +} + +type EtcdConfigerProvider struct{} + +// Parse = ParseData([]byte(key)) +// key must be json +func (provider *EtcdConfigerProvider) Parse(key string) (config.Configer, error) { + return provider.ParseData([]byte(key)) +} + +// ParseData try to parse key as clientv3.Config, using this to build etcdClient +func (provider *EtcdConfigerProvider) ParseData(data []byte) (config.Configer, error) { + cfg := &clientv3.Config{} + err := json.Unmarshal(data, cfg) + if err != nil { + return nil, fmt.Errorf("parse data to etcd config failed, please check your input: %w", err) + } + + cfg.DialOptions = []grpc.DialOption{ + grpc.WithBlock(), + grpc.WithUnaryInterceptor(grpc_prometheus.UnaryClientInterceptor), + grpc.WithStreamInterceptor(grpc_prometheus.StreamClientInterceptor), + } + client, err := clientv3.New(*cfg) + if err != nil { + return nil, fmt.Errorf("create etcd client failed: %w", err) + } + + return newEtcdConfiger(client, ""), nil +} + +func get(client *clientv3.Client, key string) (*clientv3.GetResponse, error) { + var ( + resp *clientv3.GetResponse + err error + ) + resp, err = client.Get(context.Background(), key) + + if err != nil { + return nil, fmt.Errorf("read config from etcd with key %s failed: %w", key, err) + } + return resp, err +} + +func init() { + config.Register("etcd", &EtcdConfigerProvider{}) +} diff --git a/core/config/etcd/config_test.go b/core/config/etcd/config_test.go new file mode 100644 index 0000000000..ffaf272734 --- /dev/null +++ b/core/config/etcd/config_test.go @@ -0,0 +1,116 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package etcd + +import ( + "encoding/json" + "os" + "testing" + "time" + + "github.com/stretchr/testify/assert" + clientv3 "go.etcd.io/etcd/client/v3" +) + +func TestEtcdConfigerProvider_Parse(t *testing.T) { + provider := &EtcdConfigerProvider{} + cfger, err := provider.Parse(readEtcdConfig()) + assert.Nil(t, err) + assert.NotNil(t, cfger) +} + +func TestEtcdConfiger(t *testing.T) { + provider := &EtcdConfigerProvider{} + cfger, _ := provider.Parse(readEtcdConfig()) + + subCfger, err := cfger.Sub("sub.") + assert.Nil(t, err) + assert.NotNil(t, subCfger) + + subSubCfger, err := subCfger.Sub("sub.") + assert.NotNil(t, subSubCfger) + assert.Nil(t, err) + + str, err := subSubCfger.String("key1") + assert.Nil(t, err) + assert.Equal(t, "sub.sub.key", str) + + // we cannot test it + subSubCfger.OnChange("watch", func(value string) { + // do nothing + }) + + defStr := cfger.DefaultString("not_exit", "default value") + assert.Equal(t, "default value", defStr) + + defInt64 := cfger.DefaultInt64("not_exit", -1) + assert.Equal(t, int64(-1), defInt64) + + defInt := cfger.DefaultInt("not_exit", -2) + assert.Equal(t, -2, defInt) + + defFlt := cfger.DefaultFloat("not_exit", 12.3) + assert.Equal(t, 12.3, defFlt) + + defBl := cfger.DefaultBool("not_exit", true) + assert.True(t, defBl) + + defStrs := cfger.DefaultStrings("not_exit", []string{"hello"}) + assert.Equal(t, []string{"hello"}, defStrs) + + fl, err := cfger.Float("current.float") + assert.Nil(t, err) + assert.Equal(t, 1.23, fl) + + bl, err := cfger.Bool("current.bool") + assert.Nil(t, err) + assert.True(t, bl) + + it, err := cfger.Int("current.int") + assert.Nil(t, err) + assert.Equal(t, 11, it) + + str, err = cfger.String("current.string") + assert.Nil(t, err) + assert.Equal(t, "hello", str) + + tn := &TestEntity{} + err = cfger.Unmarshaler("current.serialize.", tn) + assert.Nil(t, err) + assert.Equal(t, "test", tn.Name) +} + +type TestEntity struct { + Name string `yaml:"name"` + Sub SubEntity `yaml:"sub"` +} + +type SubEntity struct { + SubName string `yaml:"subName"` +} + +func readEtcdConfig() string { + addr := os.Getenv("ETCD_ADDR") + if addr == "" { + addr = "localhost:2379" + } + + obj := clientv3.Config{ + Endpoints: []string{addr}, + DialTimeout: 3 * time.Second, + } + cfg, _ := json.Marshal(obj) + return string(cfg) +} diff --git a/config/fake.go b/core/config/fake.go similarity index 71% rename from config/fake.go rename to core/config/fake.go index d21ab820dc..3f6f46827c 100644 --- a/config/fake.go +++ b/core/config/fake.go @@ -15,12 +15,14 @@ package config import ( + "context" "errors" "strconv" "strings" ) type fakeConfigContainer struct { + BaseConfiger data map[string]string } @@ -33,42 +35,14 @@ func (c *fakeConfigContainer) Set(key, val string) error { return nil } -func (c *fakeConfigContainer) String(key string) string { - return c.getData(key) -} - -func (c *fakeConfigContainer) DefaultString(key string, defaultval string) string { - v := c.String(key) - if v == "" { - return defaultval - } - return v -} - -func (c *fakeConfigContainer) Strings(key string) []string { - v := c.String(key) - if v == "" { - return nil - } - return strings.Split(v, ";") -} - -func (c *fakeConfigContainer) DefaultStrings(key string, defaultval []string) []string { - v := c.Strings(key) - if v == nil { - return defaultval - } - return v -} - func (c *fakeConfigContainer) Int(key string) (int, error) { return strconv.Atoi(c.getData(key)) } -func (c *fakeConfigContainer) DefaultInt(key string, defaultval int) int { +func (c *fakeConfigContainer) DefaultInt(key string, defaultVal int) int { v, err := c.Int(key) if err != nil { - return defaultval + return defaultVal } return v } @@ -77,10 +51,10 @@ func (c *fakeConfigContainer) Int64(key string) (int64, error) { return strconv.ParseInt(c.getData(key), 10, 64) } -func (c *fakeConfigContainer) DefaultInt64(key string, defaultval int64) int64 { +func (c *fakeConfigContainer) DefaultInt64(key string, defaultVal int64) int64 { v, err := c.Int64(key) if err != nil { - return defaultval + return defaultVal } return v } @@ -89,10 +63,10 @@ func (c *fakeConfigContainer) Bool(key string) (bool, error) { return ParseBool(c.getData(key)) } -func (c *fakeConfigContainer) DefaultBool(key string, defaultval bool) bool { +func (c *fakeConfigContainer) DefaultBool(key string, defaultVal bool) bool { v, err := c.Bool(key) if err != nil { - return defaultval + return defaultVal } return v } @@ -101,10 +75,10 @@ func (c *fakeConfigContainer) Float(key string) (float64, error) { return strconv.ParseFloat(c.getData(key), 64) } -func (c *fakeConfigContainer) DefaultFloat(key string, defaultval float64) float64 { +func (c *fakeConfigContainer) DefaultFloat(key string, defaultVal float64) float64 { v, err := c.Float(key) if err != nil { - return defaultval + return defaultVal } return v } @@ -124,11 +98,19 @@ func (c *fakeConfigContainer) SaveConfigFile(filename string) error { return errors.New("not implement in the fakeConfigContainer") } +func (c *fakeConfigContainer) Unmarshaler(prefix string, obj interface{}, opt ...DecodeOption) error { + return errors.New("unsupported operation") +} + var _ Configer = new(fakeConfigContainer) // NewFakeConfig return a fake Configer func NewFakeConfig() Configer { - return &fakeConfigContainer{ + res := &fakeConfigContainer{ data: make(map[string]string), } + res.BaseConfiger = NewBaseConfiger(func(ctx context.Context, key string) (string, error) { + return res.getData(key), nil + }) + return res } diff --git a/core/config/global.go b/core/config/global.go new file mode 100644 index 0000000000..6f692fce54 --- /dev/null +++ b/core/config/global.go @@ -0,0 +1,111 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package config + +// We use this to simply application's development +// for most users, they only need to use those methods +var globalInstance Configer + +// InitGlobalInstance will ini the global instance +// If you want to use specific implementation, don't forget to import it. +// e.g. _ import "github.com/beego/beego/v2/core/config/etcd" +// err := InitGlobalInstance("etcd", "someconfig") +func InitGlobalInstance(name string, cfg string) error { + var err error + globalInstance, err = NewConfig(name, cfg) + return err +} + +// Set support section::key type in given key when using ini type. +func Set(key, val string) error { + return globalInstance.Set(key, val) +} + +// String support section::key type in key string when using ini and json type; Int,Int64,Bool,Float,DIY are same. +func String(key string) (string, error) { + return globalInstance.String(key) +} + +// Strings will get string slice +func Strings(key string) ([]string, error) { + return globalInstance.Strings(key) +} + +func Int(key string) (int, error) { + return globalInstance.Int(key) +} + +func Int64(key string) (int64, error) { + return globalInstance.Int64(key) +} + +func Bool(key string) (bool, error) { + return globalInstance.Bool(key) +} + +func Float(key string) (float64, error) { + return globalInstance.Float(key) +} + +// DefaultString support section::key type in key string when using ini and json type; Int,Int64,Bool,Float,DIY are same. +func DefaultString(key string, defaultVal string) string { + return globalInstance.DefaultString(key, defaultVal) +} + +// DefaultStrings will get string slice +func DefaultStrings(key string, defaultVal []string) []string { + return globalInstance.DefaultStrings(key, defaultVal) +} + +func DefaultInt(key string, defaultVal int) int { + return globalInstance.DefaultInt(key, defaultVal) +} + +func DefaultInt64(key string, defaultVal int64) int64 { + return globalInstance.DefaultInt64(key, defaultVal) +} + +func DefaultBool(key string, defaultVal bool) bool { + return globalInstance.DefaultBool(key, defaultVal) +} + +func DefaultFloat(key string, defaultVal float64) float64 { + return globalInstance.DefaultFloat(key, defaultVal) +} + +// DIY return the original value +func DIY(key string) (interface{}, error) { + return globalInstance.DIY(key) +} + +func GetSection(section string) (map[string]string, error) { + return globalInstance.GetSection(section) +} + +func Unmarshaler(prefix string, obj interface{}, opt ...DecodeOption) error { + return globalInstance.Unmarshaler(prefix, obj, opt...) +} + +func Sub(key string) (Configer, error) { + return globalInstance.Sub(key) +} + +func OnChange(key string, fn func(value string)) { + globalInstance.OnChange(key, fn) +} + +func SaveConfigFile(filename string) error { + return globalInstance.SaveConfigFile(filename) +} diff --git a/core/config/global_test.go b/core/config/global_test.go new file mode 100644 index 0000000000..ff01b043e2 --- /dev/null +++ b/core/config/global_test.go @@ -0,0 +1,104 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package config + +import ( + "os" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestGlobalInstance(t *testing.T) { + cfgStr := ` +appname = beeapi +httpport = 8080 +mysqlport = 3600 +PI = 3.1415926 +runmode = "dev" +autorender = false +copyrequestbody = true +session= on +cookieon= off +newreg = OFF +needlogin = ON +enableSession = Y +enableCookie = N +developer="tom;jerry" +flag = 1 +path1 = ${GOPATH} +path2 = ${GOPATH||/home/go} +[demo] +key1="asta" +key2 = "xie" +CaseInsensitive = true +peers = one;two;three +password = ${GOPATH} +` + path := os.TempDir() + string(os.PathSeparator) + "test_global_instance.ini" + f, err := os.Create(path) + if err != nil { + t.Fatal(err) + } + _, err = f.WriteString(cfgStr) + if err != nil { + f.Close() + t.Fatal(err) + } + f.Close() + defer os.Remove(path) + + err = InitGlobalInstance("ini", path) + assert.Nil(t, err) + + val, err := String("appname") + assert.Nil(t, err) + assert.Equal(t, "beeapi", val) + + val = DefaultString("appname__", "404") + assert.Equal(t, "404", val) + + vi, err := Int("httpport") + assert.Nil(t, err) + assert.Equal(t, 8080, vi) + vi = DefaultInt("httpport__", 404) + assert.Equal(t, 404, vi) + + vi64, err := Int64("mysqlport") + assert.Nil(t, err) + assert.Equal(t, int64(3600), vi64) + vi64 = DefaultInt64("mysqlport__", 404) + assert.Equal(t, int64(404), vi64) + + vf, err := Float("PI") + assert.Nil(t, err) + assert.Equal(t, 3.1415926, vf) + vf = DefaultFloat("PI__", 4.04) + assert.Equal(t, 4.04, vf) + + vb, err := Bool("copyrequestbody") + assert.Nil(t, err) + assert.True(t, vb) + + vb = DefaultBool("copyrequestbody__", false) + assert.False(t, vb) + + vss := DefaultStrings("developer__", []string{"tom", ""}) + assert.Equal(t, []string{"tom", ""}, vss) + + vss, err = Strings("developer") + assert.Nil(t, err) + assert.Equal(t, []string{"tom", "jerry"}, vss) +} diff --git a/config/ini.go b/core/config/ini.go similarity index 84% rename from config/ini.go rename to core/config/ini.go index 002e5e0566..35784bce73 100644 --- a/config/ini.go +++ b/core/config/ini.go @@ -17,15 +17,18 @@ package config import ( "bufio" "bytes" + "context" "errors" + "fmt" "io" - "io/ioutil" "os" "os/user" "path/filepath" "strconv" "strings" "sync" + + "github.com/mitchellh/mapstructure" ) var ( @@ -41,8 +44,7 @@ var ( ) // IniConfig implements Config to parse ini file. -type IniConfig struct { -} +type IniConfig struct{} // Parse creates a new Config and parses the file configuration from the named file. func (ini *IniConfig) Parse(name string) (Configer, error) { @@ -50,7 +52,7 @@ func (ini *IniConfig) Parse(name string) (Configer, error) { } func (ini *IniConfig) parseFile(name string) (*IniConfigContainer, error) { - data, err := ioutil.ReadFile(name) + data, err := os.ReadFile(name) if err != nil { return nil, err } @@ -65,6 +67,10 @@ func (ini *IniConfig) parseData(dir string, data []byte) (*IniConfigContainer, e keyComment: make(map[string]string), RWMutex: sync.RWMutex{}, } + + cfg.BaseConfiger = NewBaseConfiger(func(ctx context.Context, key string) (string, error) { + return cfg.getdata(key), nil + }) cfg.Lock() defer cfg.Unlock() @@ -90,7 +96,7 @@ func (ini *IniConfig) parseData(dir string, data []byte) (*IniConfigContainer, e break } - //It might be a good idea to throw a error on all unknonw errors? + // It might be a good idea to throw an error on all unknonw errors? if _, ok := err.(*os.PathError); ok { return nil, err } @@ -222,9 +228,10 @@ func (ini *IniConfig) ParseData(data []byte) (Configer, error) { return ini.parseData(dir, data) } -// IniConfigContainer A Config represents the ini configuration. +// IniConfigContainer is a config which represents the ini configuration. // When set and get value, support key as section:name type. type IniConfigContainer struct { + BaseConfiger data map[string]map[string]string // section=> key:val sectionComment map[string]string // section : comment keyComment map[string]string // id: []{comment, key...}; id 1 is for main comment. @@ -237,11 +244,11 @@ func (c *IniConfigContainer) Bool(key string) (bool, error) { } // DefaultBool returns the boolean value for a given key. -// if err != nil return defaultval -func (c *IniConfigContainer) DefaultBool(key string, defaultval bool) bool { +// if err != nil return defaultVal +func (c *IniConfigContainer) DefaultBool(key string, defaultVal bool) bool { v, err := c.Bool(key) if err != nil { - return defaultval + return defaultVal } return v } @@ -252,11 +259,11 @@ func (c *IniConfigContainer) Int(key string) (int, error) { } // DefaultInt returns the integer value for a given key. -// if err != nil return defaultval -func (c *IniConfigContainer) DefaultInt(key string, defaultval int) int { +// if err != nil return defaultVal +func (c *IniConfigContainer) DefaultInt(key string, defaultVal int) int { v, err := c.Int(key) if err != nil { - return defaultval + return defaultVal } return v } @@ -267,11 +274,11 @@ func (c *IniConfigContainer) Int64(key string) (int64, error) { } // DefaultInt64 returns the int64 value for a given key. -// if err != nil return defaultval -func (c *IniConfigContainer) DefaultInt64(key string, defaultval int64) int64 { +// if err != nil return defaultVal +func (c *IniConfigContainer) DefaultInt64(key string, defaultVal int64) int64 { v, err := c.Int64(key) if err != nil { - return defaultval + return defaultVal } return v } @@ -282,46 +289,46 @@ func (c *IniConfigContainer) Float(key string) (float64, error) { } // DefaultFloat returns the float64 value for a given key. -// if err != nil return defaultval -func (c *IniConfigContainer) DefaultFloat(key string, defaultval float64) float64 { +// if err != nil return defaultVal +func (c *IniConfigContainer) DefaultFloat(key string, defaultVal float64) float64 { v, err := c.Float(key) if err != nil { - return defaultval + return defaultVal } return v } // String returns the string value for a given key. -func (c *IniConfigContainer) String(key string) string { - return c.getdata(key) +func (c *IniConfigContainer) String(key string) (string, error) { + return c.getdata(key), nil } // DefaultString returns the string value for a given key. -// if err != nil return defaultval -func (c *IniConfigContainer) DefaultString(key string, defaultval string) string { - v := c.String(key) - if v == "" { - return defaultval +// if err != nil return defaultVal +func (c *IniConfigContainer) DefaultString(key string, defaultVal string) string { + v, err := c.String(key) + if v == "" || err != nil { + return defaultVal } return v } // Strings returns the []string value for a given key. // Return nil if config value does not exist or is empty. -func (c *IniConfigContainer) Strings(key string) []string { - v := c.String(key) - if v == "" { - return nil +func (c *IniConfigContainer) Strings(key string) ([]string, error) { + v, err := c.String(key) + if v == "" || err != nil { + return nil, err } - return strings.Split(v, ";") + return strings.Split(v, ";"), nil } // DefaultStrings returns the []string value for a given key. -// if err != nil return defaultval -func (c *IniConfigContainer) DefaultStrings(key string, defaultval []string) []string { - v := c.Strings(key) - if v == nil { - return defaultval +// if err != nil return defaultVal +func (c *IniConfigContainer) DefaultStrings(key string, defaultVal []string) []string { + v, err := c.Strings(key) + if v == nil || err != nil { + return defaultVal } return v } @@ -437,7 +444,7 @@ func (c *IniConfigContainer) SaveConfigFile(filename string) (err error) { // Set writes a new value for key. // if write to one section, the key need be "section::key". // if the section is not existed, it panics. -func (c *IniConfigContainer) Set(key, value string) error { +func (c *IniConfigContainer) Set(key, val string) error { c.Lock() defer c.Unlock() if len(key) == 0 { @@ -460,7 +467,7 @@ func (c *IniConfigContainer) Set(key, value string) error { if _, ok := c.data[section]; !ok { c.data[section] = make(map[string]string) } - c.data[section][k] = value + c.data[section][k] = val return nil } @@ -469,7 +476,7 @@ func (c *IniConfigContainer) DIY(key string) (v interface{}, err error) { if v, ok := c.data[strings.ToLower(key)]; ok { return v, nil } - return v, errors.New("key not find") + return v, errors.New("key not found") } // section.key or key @@ -499,6 +506,20 @@ func (c *IniConfigContainer) getdata(key string) string { return "" } +func (c *IniConfigContainer) Unmarshaler(prefix string, obj interface{}, opt ...DecodeOption) error { + if len(prefix) > 0 { + return errors.New("unsupported prefix params") + } + return mapstructure.Decode(c.data, obj) +} + func init() { Register("ini", &IniConfig{}) + + err := InitGlobalInstance("ini", "conf/app.conf") + if err != nil { + fmt.Fprint(os.Stderr, "init global config instance failed. If you do not use this, just ignore it. ", err) + } } + +// Ignore this error diff --git a/config/ini_test.go b/core/config/ini_test.go similarity index 94% rename from config/ini_test.go rename to core/config/ini_test.go index ffcdb294af..302c7ecd84 100644 --- a/config/ini_test.go +++ b/core/config/ini_test.go @@ -16,14 +16,12 @@ package config import ( "fmt" - "io/ioutil" "os" "strings" "testing" ) func TestIni(t *testing.T) { - var ( inicontext = ` ;comment one @@ -109,9 +107,9 @@ password = ${GOPATH} case bool: value, err = iniconf.Bool(k) case []string: - value = iniconf.Strings(k) + value, err = iniconf.Strings(k) case string: - value = iniconf.String(k) + value, err = iniconf.String(k) default: value, err = iniconf.DIY(k) } @@ -125,14 +123,13 @@ password = ${GOPATH} if err = iniconf.Set("name", "astaxie"); err != nil { t.Fatal(err) } - if iniconf.String("name") != "astaxie" { + res, _ := iniconf.String("name") + if res != "astaxie" { t.Fatal("get name error") } - } func TestIniSave(t *testing.T) { - const ( inicontext = ` app = app @@ -145,7 +142,7 @@ httpport = 8080 # enable db [dbinfo] # db type name -# suport mysql,sqlserver +# support mysql,sqlserver name = mysql ` @@ -161,7 +158,7 @@ httpport=8080 # enable db [dbinfo] # db type name -# suport mysql,sqlserver +# support mysql,sqlserver name=mysql ` ) @@ -175,7 +172,7 @@ name=mysql } defer os.Remove(name) - if data, err := ioutil.ReadFile(name); err != nil { + if data, err := os.ReadFile(name); err != nil { t.Fatal(err) } else { cfgData := string(data) diff --git a/config/json.go b/core/config/json/json.go similarity index 68% rename from config/json.go rename to core/config/json/json.go index 74c18c9c12..1d764733b4 100644 --- a/config/json.go +++ b/core/config/json/json.go @@ -12,30 +12,35 @@ // See the License for the specific language governing permissions and // limitations under the License. -package config +package json import ( "encoding/json" "errors" "fmt" - "io/ioutil" + "io" "os" + "strconv" "strings" "sync" + + "github.com/mitchellh/mapstructure" + + "github.com/beego/beego/v2/core/config" + "github.com/beego/beego/v2/core/logs" ) // JSONConfig is a json config parser and implements Config interface. -type JSONConfig struct { -} +type JSONConfig struct{} // Parse returns a ConfigContainer with parsed json config map. -func (js *JSONConfig) Parse(filename string) (Configer, error) { +func (js *JSONConfig) Parse(filename string) (config.Configer, error) { file, err := os.Open(filename) if err != nil { return nil, err } defer file.Close() - content, err := ioutil.ReadAll(file) + content, err := io.ReadAll(file) if err != nil { return nil, err } @@ -44,7 +49,7 @@ func (js *JSONConfig) Parse(filename string) (Configer, error) { } // ParseData returns a ConfigContainer with json string -func (js *JSONConfig) ParseData(data []byte) (Configer, error) { +func (js *JSONConfig) ParseData(data []byte) (config.Configer, error) { x := &JSONConfigContainer{ data: make(map[string]interface{}), } @@ -58,34 +63,72 @@ func (js *JSONConfig) ParseData(data []byte) (Configer, error) { x.data["rootArray"] = wrappingArray } - x.data = ExpandValueEnvForMap(x.data) + x.data = config.ExpandValueEnvForMap(x.data) return x, nil } -// JSONConfigContainer A Config represents the json configuration. +// JSONConfigContainer is a config which represents the json configuration. // Only when get value, support key as section:name type. type JSONConfigContainer struct { data map[string]interface{} sync.RWMutex } +func (c *JSONConfigContainer) Unmarshaler(prefix string, obj interface{}, opt ...config.DecodeOption) error { + sub, err := c.sub(prefix) + if err != nil { + return err + } + return mapstructure.Decode(sub, obj) +} + +func (c *JSONConfigContainer) Sub(key string) (config.Configer, error) { + sub, err := c.sub(key) + if err != nil { + return nil, err + } + return &JSONConfigContainer{ + data: sub, + }, nil +} + +func (c *JSONConfigContainer) sub(key string) (map[string]interface{}, error) { + if key == "" { + return c.data, nil + } + value, ok := c.data[key] + if !ok { + return nil, fmt.Errorf("key is not found: %s", key) + } + + res, ok := value.(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("the type of value is invalid, key: %s", key) + } + return res, nil +} + +func (c *JSONConfigContainer) OnChange(key string, fn func(value string)) { + logs.Warn("unsupported operation") +} + // Bool returns the boolean value for a given key. func (c *JSONConfigContainer) Bool(key string) (bool, error) { val := c.getData(key) if val != nil { - return ParseBool(val) + return config.ParseBool(val) } return false, fmt.Errorf("not exist key: %q", key) } // DefaultBool return the bool value if has no error // otherwise return the defaultval -func (c *JSONConfigContainer) DefaultBool(key string, defaultval bool) bool { +func (c *JSONConfigContainer) DefaultBool(key string, defaultVal bool) bool { if v, err := c.Bool(key); err == nil { return v } - return defaultval + return defaultVal } // Int returns the integer value for a given key. @@ -94,19 +137,21 @@ func (c *JSONConfigContainer) Int(key string) (int, error) { if val != nil { if v, ok := val.(float64); ok { return int(v), nil + } else if v, ok := val.(string); ok { + return strconv.Atoi(v) } - return 0, errors.New("not int value") + return 0, errors.New("not valid value") } return 0, errors.New("not exist key:" + key) } // DefaultInt returns the integer value for a given key. // if err != nil return defaultval -func (c *JSONConfigContainer) DefaultInt(key string, defaultval int) int { +func (c *JSONConfigContainer) DefaultInt(key string, defaultVal int) int { if v, err := c.Int(key); err == nil { return v } - return defaultval + return defaultVal } // Int64 returns the int64 value for a given key. @@ -123,11 +168,11 @@ func (c *JSONConfigContainer) Int64(key string) (int64, error) { // DefaultInt64 returns the int64 value for a given key. // if err != nil return defaultval -func (c *JSONConfigContainer) DefaultInt64(key string, defaultval int64) int64 { +func (c *JSONConfigContainer) DefaultInt64(key string, defaultVal int64) int64 { if v, err := c.Int64(key); err == nil { return v } - return defaultval + return defaultVal } // Float returns the float value for a given key. @@ -144,50 +189,49 @@ func (c *JSONConfigContainer) Float(key string) (float64, error) { // DefaultFloat returns the float64 value for a given key. // if err != nil return defaultval -func (c *JSONConfigContainer) DefaultFloat(key string, defaultval float64) float64 { +func (c *JSONConfigContainer) DefaultFloat(key string, defaultVal float64) float64 { if v, err := c.Float(key); err == nil { return v } - return defaultval + return defaultVal } // String returns the string value for a given key. -func (c *JSONConfigContainer) String(key string) string { +func (c *JSONConfigContainer) String(key string) (string, error) { val := c.getData(key) if val != nil { if v, ok := val.(string); ok { - return v + return v, nil } } - return "" + return "", nil } // DefaultString returns the string value for a given key. // if err != nil return defaultval -func (c *JSONConfigContainer) DefaultString(key string, defaultval string) string { - // TODO FIXME should not use "" to replace non existence - if v := c.String(key); v != "" { +func (c *JSONConfigContainer) DefaultString(key string, defaultVal string) string { + if v, err := c.String(key); v != "" && err == nil { return v } - return defaultval + return defaultVal } // Strings returns the []string value for a given key. -func (c *JSONConfigContainer) Strings(key string) []string { - stringVal := c.String(key) - if stringVal == "" { - return nil +func (c *JSONConfigContainer) Strings(key string) ([]string, error) { + stringVal, err := c.String(key) + if stringVal == "" || err != nil { + return nil, err } - return strings.Split(c.String(key), ";") + return strings.Split(stringVal, ";"), nil } // DefaultStrings returns the []string value for a given key. // if err != nil return defaultval -func (c *JSONConfigContainer) DefaultStrings(key string, defaultval []string) []string { - if v := c.Strings(key); v != nil { +func (c *JSONConfigContainer) DefaultStrings(key string, defaultVal []string) []string { + if v, err := c.Strings(key); v != nil && err == nil { return v } - return defaultval + return defaultVal } // GetSection returns map for the given section @@ -262,5 +306,5 @@ func (c *JSONConfigContainer) getData(key string) interface{} { } func init() { - Register("json", &JSONConfig{}) + config.Register("json", &JSONConfig{}) } diff --git a/config/json_test.go b/core/config/json/json_test.go similarity index 85% rename from config/json_test.go rename to core/config/json/json_test.go index 16f424095f..2d10718375 100644 --- a/config/json_test.go +++ b/core/config/json/json_test.go @@ -12,16 +12,19 @@ // See the License for the specific language governing permissions and // limitations under the License. -package config +package json import ( "fmt" "os" "testing" + + "github.com/stretchr/testify/assert" + + "github.com/beego/beego/v2/core/config" ) func TestJsonStartsWithArray(t *testing.T) { - const jsoncontextwitharray = `[ { "url": "user", @@ -43,7 +46,7 @@ func TestJsonStartsWithArray(t *testing.T) { } f.Close() defer os.Remove("testjsonWithArray.conf") - jsonconf, err := NewConfig("json", "testjsonWithArray.conf") + jsonconf, err := config.NewConfig("json", "testjsonWithArray.conf") if err != nil { t.Fatal(err) } @@ -68,7 +71,6 @@ func TestJsonStartsWithArray(t *testing.T) { } func TestJson(t *testing.T) { - var ( jsoncontext = `{ "appname": "beeapi", @@ -143,7 +145,7 @@ func TestJson(t *testing.T) { } f.Close() defer os.Remove("testjson.conf") - jsonconf, err := NewConfig("json", "testjson.conf") + jsonconf, err := config.NewConfig("json", "testjson.conf") if err != nil { t.Fatal(err) } @@ -161,9 +163,9 @@ func TestJson(t *testing.T) { case bool: value, err = jsonconf.Bool(k) case []string: - value = jsonconf.Strings(k) + value, err = jsonconf.Strings(k) case string: - value = jsonconf.String(k) + value, err = jsonconf.String(k) default: value, err = jsonconf.DIY(k) } @@ -177,7 +179,9 @@ func TestJson(t *testing.T) { if err = jsonconf.Set("name", "astaxie"); err != nil { t.Fatal(err) } - if jsonconf.String("name") != "astaxie" { + + res, _ := jsonconf.String("name") + if res != "astaxie" { t.Fatal("get name error") } @@ -208,7 +212,7 @@ func TestJson(t *testing.T) { t.Error("unknown keys should return an error when expecting an interface{}") } - if val := jsonconf.String("unknown"); val != "" { + if val, _ := jsonconf.String("unknown"); val != "" { t.Error("unknown keys should return an empty string when expecting a String") } @@ -219,4 +223,27 @@ func TestJson(t *testing.T) { if !jsonconf.DefaultBool("unknown", true) { t.Error("unknown keys with default value wrong") } + + sub, err := jsonconf.Sub("database") + assert.Nil(t, err) + assert.NotNil(t, sub) + + sub, err = sub.Sub("conns") + assert.Nil(t, err) + + maxCon, _ := sub.Int("maxconnection") + assert.Equal(t, 12, maxCon) + + dbCfg := &DatabaseConfig{} + err = sub.Unmarshaler("", dbCfg) + assert.Nil(t, err) + assert.Equal(t, 12, dbCfg.MaxConnection) + assert.True(t, dbCfg.Autoconnect) + assert.Equal(t, "info", dbCfg.Connectioninfo) +} + +type DatabaseConfig struct { + MaxConnection int `json:"maxconnection"` + Autoconnect bool `json:"autoconnect"` + Connectioninfo string `json:"connectioninfo"` } diff --git a/core/config/toml/toml.go b/core/config/toml/toml.go new file mode 100644 index 0000000000..297d16c8d6 --- /dev/null +++ b/core/config/toml/toml.go @@ -0,0 +1,350 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package toml + +import ( + "os" + "strings" + + "github.com/pelletier/go-toml" + + "github.com/beego/beego/v2/core/config" +) + +const keySeparator = "." + +type Config struct { + tree *toml.Tree +} + +// Parse accepts filename as the parameter +func (c *Config) Parse(filename string) (config.Configer, error) { + ctx, err := os.ReadFile(filename) + if err != nil { + return nil, err + } + return c.ParseData(ctx) +} + +func (c *Config) ParseData(data []byte) (config.Configer, error) { + t, err := toml.LoadBytes(data) + if err != nil { + return nil, err + } + return &configContainer{ + t: t, + }, nil +} + +// configContainer support key looks like "a.b.c" +type configContainer struct { + t *toml.Tree +} + +// Set put key, val +func (c *configContainer) Set(key, val string) error { + path := strings.Split(key, keySeparator) + sub, err := subTree(c.t, path[0:len(path)-1]) + if err != nil { + return err + } + sub.Set(path[len(path)-1], val) + return nil +} + +// String return the value. +// return error if key not found or value is invalid type +func (c *configContainer) String(key string) (string, error) { + res, err := c.get(key) + if err != nil { + return "", err + } + + if res == nil { + return "", config.KeyNotFoundError + } + + if str, ok := res.(string); ok { + return str, nil + } else { + return "", config.InvalidValueTypeError + } +} + +// Strings return []string +// return error if key not found or value is invalid type +func (c *configContainer) Strings(key string) ([]string, error) { + val, err := c.get(key) + if err != nil { + return []string{}, err + } + if val == nil { + return []string{}, config.KeyNotFoundError + } + if arr, ok := val.([]interface{}); ok { + res := make([]string, 0, len(arr)) + for _, ele := range arr { + if str, ok := ele.(string); ok { + res = append(res, str) + } else { + return []string{}, config.InvalidValueTypeError + } + } + return res, nil + } else { + return []string{}, config.InvalidValueTypeError + } +} + +// Int return int value +// return error if key not found or value is invalid type +func (c *configContainer) Int(key string) (int, error) { + val, err := c.Int64(key) + return int(val), err +} + +// Int64 return int64 value +// return error if key not found or value is invalid type +func (c *configContainer) Int64(key string) (int64, error) { + res, err := c.get(key) + if err != nil { + return 0, err + } + if res == nil { + return 0, config.KeyNotFoundError + } + if i, ok := res.(int); ok { + return int64(i), nil + } else if i64, ok := res.(int64); ok { + return i64, nil + } else { + return 0, config.InvalidValueTypeError + } +} + +// Bool return bool value +// return error if key not found or value is invalid type +func (c *configContainer) Bool(key string) (bool, error) { + res, err := c.get(key) + if err != nil { + return false, err + } + + if res == nil { + return false, config.KeyNotFoundError + } + if b, ok := res.(bool); ok { + return b, nil + } else { + return false, config.InvalidValueTypeError + } +} + +// Float return float value +// return error if key not found or value is invalid type +func (c *configContainer) Float(key string) (float64, error) { + res, err := c.get(key) + if err != nil { + return 0, err + } + + if res == nil { + return 0, config.KeyNotFoundError + } + + if f, ok := res.(float64); ok { + return f, nil + } else { + return 0, config.InvalidValueTypeError + } +} + +// DefaultString return string value +// return default value if key not found or value is invalid type +func (c *configContainer) DefaultString(key string, defaultVal string) string { + res, err := c.get(key) + if err != nil { + return defaultVal + } + if str, ok := res.(string); ok { + return str + } else { + return defaultVal + } +} + +// DefaultStrings return []string +// return default value if key not found or value is invalid type +func (c *configContainer) DefaultStrings(key string, defaultVal []string) []string { + val, err := c.get(key) + if err != nil { + return defaultVal + } + if arr, ok := val.([]interface{}); ok { + res := make([]string, 0, len(arr)) + for _, ele := range arr { + if str, ok := ele.(string); ok { + res = append(res, str) + } else { + return defaultVal + } + } + return res + } else { + return defaultVal + } +} + +// DefaultInt return int value +// return default value if key not found or value is invalid type +func (c *configContainer) DefaultInt(key string, defaultVal int) int { + return int(c.DefaultInt64(key, int64(defaultVal))) +} + +// DefaultInt64 return int64 value +// return default value if key not found or value is invalid type +func (c *configContainer) DefaultInt64(key string, defaultVal int64) int64 { + res, err := c.get(key) + if err != nil { + return defaultVal + } + if i, ok := res.(int); ok { + return int64(i) + } else if i64, ok := res.(int64); ok { + return i64 + } else { + return defaultVal + } +} + +// DefaultBool return bool value +// return default value if key not found or value is invalid type +func (c *configContainer) DefaultBool(key string, defaultVal bool) bool { + res, err := c.get(key) + if err != nil { + return defaultVal + } + if b, ok := res.(bool); ok { + return b + } else { + return defaultVal + } +} + +// DefaultFloat return float value +// return default value if key not found or value is invalid type +func (c *configContainer) DefaultFloat(key string, defaultVal float64) float64 { + res, err := c.get(key) + if err != nil { + return defaultVal + } + if f, ok := res.(float64); ok { + return f + } else { + return defaultVal + } +} + +// DIY returns the original value +func (c *configContainer) DIY(key string) (interface{}, error) { + return c.get(key) +} + +// GetSection return error if the value is not valid toml doc +func (c *configContainer) GetSection(section string) (map[string]string, error) { + val, err := subTree(c.t, strings.Split(section, keySeparator)) + if err != nil { + return map[string]string{}, err + } + m := val.ToMap() + res := make(map[string]string, len(m)) + for k, v := range m { + res[k] = config.ToString(v) + } + return res, nil +} + +func (c *configContainer) Unmarshaler(prefix string, obj interface{}, opt ...config.DecodeOption) error { + if len(prefix) > 0 { + t, err := subTree(c.t, strings.Split(prefix, keySeparator)) + if err != nil { + return err + } + return t.Unmarshal(obj) + } + return c.t.Unmarshal(obj) +} + +// Sub return sub configer +// return error if key not found or the value is not a sub doc +func (c *configContainer) Sub(key string) (config.Configer, error) { + val, err := subTree(c.t, strings.Split(key, keySeparator)) + if err != nil { + return nil, err + } + return &configContainer{ + t: val, + }, nil +} + +// OnChange do nothing +func (c *configContainer) OnChange(key string, fn func(value string)) { + // do nothing +} + +// SaveConfigFile create or override the file +func (c *configContainer) SaveConfigFile(filename string) error { + // Write configuration file by filename. + f, err := os.Create(filename) + if err != nil { + return err + } + defer f.Close() + + _, err = c.t.WriteTo(f) + return err +} + +func (c *configContainer) get(key string) (interface{}, error) { + if len(key) == 0 { + return nil, config.KeyNotFoundError + } + + segs := strings.Split(key, keySeparator) + t, err := subTree(c.t, segs[0:len(segs)-1]) + if err != nil { + return nil, err + } + return t.Get(segs[len(segs)-1]), nil +} + +func subTree(t *toml.Tree, path []string) (*toml.Tree, error) { + res := t + for i := 0; i < len(path); i++ { + if subTree, ok := res.Get(path[i]).(*toml.Tree); ok { + res = subTree + } else { + return nil, config.InvalidValueTypeError + } + } + if res == nil { + return nil, config.KeyNotFoundError + } + return res, nil +} + +func init() { + config.Register("toml", &Config{}) +} diff --git a/core/config/toml/toml_test.go b/core/config/toml/toml_test.go new file mode 100644 index 0000000000..2867f4be96 --- /dev/null +++ b/core/config/toml/toml_test.go @@ -0,0 +1,379 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package toml + +import ( + "fmt" + "os" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/beego/beego/v2/core/config" +) + +func TestConfig_Parse(t *testing.T) { + // file not found + cfg := &Config{} + _, err := cfg.Parse("invalid_file_name.txt") + assert.NotNil(t, err) +} + +func TestConfig_ParseData(t *testing.T) { + data := ` +name="Tom" +` + cfg := &Config{} + c, err := cfg.ParseData([]byte(data)) + assert.Nil(t, err) + assert.NotNil(t, c) +} + +func TestConfigContainer_Bool(t *testing.T) { + data := ` +Man=true +Woman="true" +` + cfg := &Config{} + c, err := cfg.ParseData([]byte(data)) + assert.Nil(t, err) + assert.NotNil(t, c) + + val, err := c.Bool("Man") + assert.Nil(t, err) + assert.True(t, val) + + _, err = c.Bool("Woman") + assert.NotNil(t, err) + assert.Equal(t, config.InvalidValueTypeError, err) +} + +func TestConfigContainer_DefaultBool(t *testing.T) { + data := ` +Man=true +Woman="false" +` + cfg := &Config{} + c, err := cfg.ParseData([]byte(data)) + assert.Nil(t, err) + assert.NotNil(t, c) + + val := c.DefaultBool("Man11", true) + assert.True(t, val) + + val = c.DefaultBool("Man", false) + assert.True(t, val) + + val = c.DefaultBool("Woman", true) + assert.True(t, val) +} + +func TestConfigContainer_DefaultFloat(t *testing.T) { + data := ` +Price=12.3 +PriceInvalid="12.3" +` + cfg := &Config{} + c, err := cfg.ParseData([]byte(data)) + assert.Nil(t, err) + assert.NotNil(t, c) + + val := c.DefaultFloat("Price", 11.2) + assert.Equal(t, 12.3, val) + + val = c.DefaultFloat("Price11", 11.2) + assert.Equal(t, 11.2, val) + + val = c.DefaultFloat("PriceInvalid", 11.2) + assert.Equal(t, 11.2, val) +} + +func TestConfigContainer_DefaultInt(t *testing.T) { + data := ` +Age=12 +AgeInvalid="13" +` + cfg := &Config{} + c, err := cfg.ParseData([]byte(data)) + assert.Nil(t, err) + assert.NotNil(t, c) + + val := c.DefaultInt("Age", 11) + assert.Equal(t, 12, val) + + val = c.DefaultInt("Price11", 11) + assert.Equal(t, 11, val) + + val = c.DefaultInt("PriceInvalid", 11) + assert.Equal(t, 11, val) +} + +func TestConfigContainer_DefaultString(t *testing.T) { + data := ` +Name="Tom" +NameInvalid=13 +` + cfg := &Config{} + c, err := cfg.ParseData([]byte(data)) + assert.Nil(t, err) + assert.NotNil(t, c) + + val := c.DefaultString("Name", "Jerry") + assert.Equal(t, "Tom", val) + + val = c.DefaultString("Name11", "Jerry") + assert.Equal(t, "Jerry", val) + + val = c.DefaultString("NameInvalid", "Jerry") + assert.Equal(t, "Jerry", val) +} + +func TestConfigContainer_DefaultStrings(t *testing.T) { + data := ` +Name=["Tom", "Jerry"] +NameInvalid="Tom" +` + cfg := &Config{} + c, err := cfg.ParseData([]byte(data)) + assert.Nil(t, err) + assert.NotNil(t, c) + + val := c.DefaultStrings("Name", []string{"Jerry"}) + assert.Equal(t, []string{"Tom", "Jerry"}, val) + + val = c.DefaultStrings("Name11", []string{"Jerry"}) + assert.Equal(t, []string{"Jerry"}, val) + + val = c.DefaultStrings("NameInvalid", []string{"Jerry"}) + assert.Equal(t, []string{"Jerry"}, val) +} + +func TestConfigContainer_DIY(t *testing.T) { + data := ` +Name=["Tom", "Jerry"] +` + cfg := &Config{} + c, err := cfg.ParseData([]byte(data)) + assert.Nil(t, err) + assert.NotNil(t, c) + + _, err = c.DIY("Name") + assert.Nil(t, err) +} + +func TestConfigContainer_Float(t *testing.T) { + data := ` +Price=12.3 +PriceInvalid="12.3" +` + cfg := &Config{} + c, err := cfg.ParseData([]byte(data)) + assert.Nil(t, err) + assert.NotNil(t, c) + + val, err := c.Float("Price") + assert.Nil(t, err) + assert.Equal(t, 12.3, val) + + _, err = c.Float("Price11") + assert.Equal(t, config.KeyNotFoundError, err) + + _, err = c.Float("PriceInvalid") + assert.Equal(t, config.InvalidValueTypeError, err) +} + +func TestConfigContainer_Int(t *testing.T) { + data := ` +Age=12 +AgeInvalid="13" +` + cfg := &Config{} + c, err := cfg.ParseData([]byte(data)) + assert.Nil(t, err) + assert.NotNil(t, c) + + val, err := c.Int("Age") + assert.Nil(t, err) + assert.Equal(t, 12, val) + + _, err = c.Int("Age11") + assert.Equal(t, config.KeyNotFoundError, err) + + _, err = c.Int("AgeInvalid") + assert.Equal(t, config.InvalidValueTypeError, err) +} + +func TestConfigContainer_GetSection(t *testing.T) { + data := ` +[servers] + + # You can indent as you please. Tabs or spaces. TOML don't care. + [servers.alpha] + ip = "10.0.0.1" + dc = "eqdc10" + + [servers.beta] + ip = "10.0.0.2" + dc = "eqdc10" +` + cfg := &Config{} + c, err := cfg.ParseData([]byte(data)) + assert.Nil(t, err) + assert.NotNil(t, c) + + m, err := c.GetSection("servers") + assert.Nil(t, err) + assert.NotNil(t, m) + assert.Equal(t, 2, len(m)) +} + +func TestConfigContainer_String(t *testing.T) { + data := ` +Name="Tom" +NameInvalid=13 +[Person] +Name="Jerry" +` + cfg := &Config{} + c, err := cfg.ParseData([]byte(data)) + assert.Nil(t, err) + assert.NotNil(t, c) + + val, err := c.String("Name") + assert.Nil(t, err) + assert.Equal(t, "Tom", val) + + _, err = c.String("Name11") + assert.Equal(t, config.KeyNotFoundError, err) + + _, err = c.String("NameInvalid") + assert.Equal(t, config.InvalidValueTypeError, err) + + val, err = c.String("Person.Name") + assert.Nil(t, err) + assert.Equal(t, "Jerry", val) +} + +func TestConfigContainer_Strings(t *testing.T) { + data := ` +Name=["Tom", "Jerry"] +NameInvalid="Tom" +` + cfg := &Config{} + c, err := cfg.ParseData([]byte(data)) + assert.Nil(t, err) + assert.NotNil(t, c) + + val, err := c.Strings("Name") + assert.Nil(t, err) + assert.Equal(t, []string{"Tom", "Jerry"}, val) + + _, err = c.Strings("Name11") + assert.Equal(t, config.KeyNotFoundError, err) + + _, err = c.Strings("NameInvalid") + assert.Equal(t, config.InvalidValueTypeError, err) +} + +func TestConfigContainer_Set(t *testing.T) { + data := ` +Name=["Tom", "Jerry"] +NameInvalid="Tom" +` + cfg := &Config{} + c, err := cfg.ParseData([]byte(data)) + assert.Nil(t, err) + assert.NotNil(t, c) + + err = c.Set("Age", "11") + assert.Nil(t, err) + age, err := c.String("Age") + assert.Nil(t, err) + assert.Equal(t, "11", age) +} + +func TestConfigContainer_SubAndMushall(t *testing.T) { + data := ` +[servers] + + # You can indent as you please. Tabs or spaces. TOML don't care. + [servers.alpha] + ip = "10.0.0.1" + dc = "eqdc10" + + [servers.beta] + ip = "10.0.0.2" + dc = "eqdc10" +` + cfg := &Config{} + c, err := cfg.ParseData([]byte(data)) + assert.Nil(t, err) + assert.NotNil(t, c) + + sub, err := c.Sub("servers") + assert.Nil(t, err) + assert.NotNil(t, sub) + + sub, err = sub.Sub("alpha") + assert.Nil(t, err) + assert.NotNil(t, sub) + ip, err := sub.String("ip") + assert.Nil(t, err) + assert.Equal(t, "10.0.0.1", ip) + + svr := &Server{} + err = sub.Unmarshaler("", svr) + assert.Nil(t, err) + assert.Equal(t, "10.0.0.1", svr.Ip) + + svr = &Server{} + err = c.Unmarshaler("servers.alpha", svr) + assert.Nil(t, err) + assert.Equal(t, "10.0.0.1", svr.Ip) +} + +func TestConfigContainer_SaveConfigFile(t *testing.T) { + filename := "test_config.toml" + path := os.TempDir() + string(os.PathSeparator) + filename + data := ` +[servers] + + # You can indent as you please. Tabs or spaces. TOML don't care. + [servers.alpha] + ip = "10.0.0.1" + dc = "eqdc10" + + [servers.beta] + ip = "10.0.0.2" + dc = "eqdc10" +` + cfg := &Config{} + c, err := cfg.ParseData([]byte(data)) + + fmt.Println(path) + + assert.Nil(t, err) + assert.NotNil(t, c) + + sub, err := c.Sub("servers") + assert.Nil(t, err) + + err = sub.SaveConfigFile(path) + assert.Nil(t, err) +} + +type Server struct { + Ip string `toml:"ip"` +} diff --git a/config/xml/xml.go b/core/config/xml/xml.go similarity index 60% rename from config/xml/xml.go rename to core/config/xml/xml.go index 494242d319..8aa7e26bea 100644 --- a/config/xml/xml.go +++ b/core/config/xml/xml.go @@ -19,28 +19,30 @@ // go install github.com/beego/x2j. // // Usage: -// import( -// _ "github.com/astaxie/beego/config/xml" -// "github.com/astaxie/beego/config" -// ) // -// cnf, err := config.NewConfig("xml", "config.xml") +// import( +// _ "github.com/beego/beego/v2/core/config/xml" +// "github.com/beego/beego/v2/core/config" +// ) // -//More docs http://beego.me/docs/module/config.md +// cnf, err := config.NewConfig("xml", "config.xml") package xml import ( "encoding/xml" "errors" "fmt" - "io/ioutil" "os" "strconv" "strings" "sync" - "github.com/astaxie/beego/config" + "github.com/mitchellh/mapstructure" + "github.com/beego/x2j" + + "github.com/beego/beego/v2/core/config" + "github.com/beego/beego/v2/core/logs" ) // Config is a xml config parser and implements Config interface. @@ -50,7 +52,7 @@ type Config struct{} // Parse returns a ConfigContainer with parsed xml config map. func (xc *Config) Parse(filename string) (config.Configer, error) { - context, err := ioutil.ReadFile(filename) + context, err := os.ReadFile(filename) if err != nil { return nil, err } @@ -67,17 +69,69 @@ func (xc *Config) ParseData(data []byte) (config.Configer, error) { return nil, err } - x.data = config.ExpandValueEnvForMap(d["config"].(map[string]interface{})) + v := d["config"] + if v == nil { + return nil, fmt.Errorf("xml parse should include in tags") + } + + confVal, ok := v.(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("xml parse tags should include sub tags") + } + + x.data = config.ExpandValueEnvForMap(confVal) return x, nil } -// ConfigContainer A Config represents the xml configuration. +// ConfigContainer is a Config which represents the xml configuration. type ConfigContainer struct { data map[string]interface{} sync.Mutex } +// Unmarshaler is a little be inconvenient since the xml library doesn't know type. +// So when you use +// 1 +// The "1" is a string, not int +func (c *ConfigContainer) Unmarshaler(prefix string, obj interface{}, opt ...config.DecodeOption) error { + sub, err := c.sub(prefix) + if err != nil { + return err + } + return mapstructure.Decode(sub, obj) +} + +func (c *ConfigContainer) Sub(key string) (config.Configer, error) { + sub, err := c.sub(key) + if err != nil { + return nil, err + } + + return &ConfigContainer{ + data: sub, + }, nil +} + +func (c *ConfigContainer) sub(key string) (map[string]interface{}, error) { + if key == "" { + return c.data, nil + } + value, ok := c.data[key] + if !ok { + return nil, fmt.Errorf("the key is not found: %s", key) + } + res, ok := value.(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("the value of this key is not a structure: %s", key) + } + return res, nil +} + +func (c *ConfigContainer) OnChange(key string, fn func(value string)) { + logs.Warn("Unsupported operation") +} + // Bool returns the boolean value for a given key. func (c *ConfigContainer) Bool(key string) (bool, error) { if v := c.data[key]; v != nil { @@ -87,11 +141,11 @@ func (c *ConfigContainer) Bool(key string) (bool, error) { } // DefaultBool return the bool value if has no error -// otherwise return the defaultval -func (c *ConfigContainer) DefaultBool(key string, defaultval bool) bool { +// otherwise return the defaultVal +func (c *ConfigContainer) DefaultBool(key string, defaultVal bool) bool { v, err := c.Bool(key) if err != nil { - return defaultval + return defaultVal } return v } @@ -102,11 +156,11 @@ func (c *ConfigContainer) Int(key string) (int, error) { } // DefaultInt returns the integer value for a given key. -// if err != nil return defaultval -func (c *ConfigContainer) DefaultInt(key string, defaultval int) int { +// if err != nil return defaultVal +func (c *ConfigContainer) DefaultInt(key string, defaultVal int) int { v, err := c.Int(key) if err != nil { - return defaultval + return defaultVal } return v } @@ -117,14 +171,13 @@ func (c *ConfigContainer) Int64(key string) (int64, error) { } // DefaultInt64 returns the int64 value for a given key. -// if err != nil return defaultval -func (c *ConfigContainer) DefaultInt64(key string, defaultval int64) int64 { +// if err != nil return defaultVal +func (c *ConfigContainer) DefaultInt64(key string, defaultVal int64) int64 { v, err := c.Int64(key) if err != nil { - return defaultval + return defaultVal } return v - } // Float returns the float value for a given key. @@ -133,48 +186,48 @@ func (c *ConfigContainer) Float(key string) (float64, error) { } // DefaultFloat returns the float64 value for a given key. -// if err != nil return defaultval -func (c *ConfigContainer) DefaultFloat(key string, defaultval float64) float64 { +// if err != nil return defaultVal +func (c *ConfigContainer) DefaultFloat(key string, defaultVal float64) float64 { v, err := c.Float(key) if err != nil { - return defaultval + return defaultVal } return v } // String returns the string value for a given key. -func (c *ConfigContainer) String(key string) string { +func (c *ConfigContainer) String(key string) (string, error) { if v, ok := c.data[key].(string); ok { - return v + return v, nil } - return "" + return "", nil } // DefaultString returns the string value for a given key. -// if err != nil return defaultval -func (c *ConfigContainer) DefaultString(key string, defaultval string) string { - v := c.String(key) - if v == "" { - return defaultval +// if err != nil return defaultVal +func (c *ConfigContainer) DefaultString(key string, defaultVal string) string { + v, err := c.String(key) + if v == "" || err != nil { + return defaultVal } return v } // Strings returns the []string value for a given key. -func (c *ConfigContainer) Strings(key string) []string { - v := c.String(key) - if v == "" { - return nil +func (c *ConfigContainer) Strings(key string) ([]string, error) { + v, err := c.String(key) + if v == "" || err != nil { + return nil, err } - return strings.Split(v, ";") + return strings.Split(v, ";"), nil } // DefaultStrings returns the []string value for a given key. -// if err != nil return defaultval -func (c *ConfigContainer) DefaultStrings(key string, defaultval []string) []string { - v := c.Strings(key) - if v == nil { - return defaultval +// if err != nil return defaultVal +func (c *ConfigContainer) DefaultStrings(key string, defaultVal []string) []string { + v, err := c.Strings(key) + if v == nil || err != nil { + return defaultVal } return v } diff --git a/config/xml/xml_test.go b/core/config/xml/xml_test.go similarity index 67% rename from config/xml/xml_test.go rename to core/config/xml/xml_test.go index 346c866ee0..2807be693f 100644 --- a/config/xml/xml_test.go +++ b/core/config/xml/xml_test.go @@ -19,13 +19,14 @@ import ( "os" "testing" - "github.com/astaxie/beego/config" + "github.com/stretchr/testify/assert" + + "github.com/beego/beego/v2/core/config" ) func TestXML(t *testing.T) { - var ( - //xml parse should incluce in tags + // xml parse should include in tags xmlcontext = ` beeapi @@ -102,9 +103,9 @@ func TestXML(t *testing.T) { case bool: value, err = xmlconf.Bool(k) case []string: - value = xmlconf.Strings(k) + value, err = xmlconf.Strings(k) case string: - value = xmlconf.String(k) + value, err = xmlconf.String(k) default: value, err = xmlconf.DIY(k) } @@ -119,7 +120,55 @@ func TestXML(t *testing.T) { if err = xmlconf.Set("name", "astaxie"); err != nil { t.Fatal(err) } - if xmlconf.String("name") != "astaxie" { + + res, _ := xmlconf.String("name") + if res != "astaxie" { t.Fatal("get name error") } + + sub, err := xmlconf.Sub("mysection") + assert.Nil(t, err) + assert.NotNil(t, sub) + name, err := sub.String("name") + assert.Nil(t, err) + assert.Equal(t, "MySection", name) + + id, err := sub.Int("id") + assert.Nil(t, err) + assert.Equal(t, 1, id) + + sec := &Section{} + + err = sub.Unmarshaler("", sec) + assert.Nil(t, err) + assert.Equal(t, "MySection", sec.Name) + + sec = &Section{} + + err = xmlconf.Unmarshaler("mysection", sec) + assert.Nil(t, err) + assert.Equal(t, "MySection", sec.Name) +} + +func TestXMLMissConfig(t *testing.T) { + xmlcontext1 := ` + + beeapi + ` + + c := &Config{} + _, err := c.ParseData([]byte(xmlcontext1)) + assert.Equal(t, "xml parse should include in tags", err.Error()) + + xmlcontext2 := ` + + + ` + + _, err = c.ParseData([]byte(xmlcontext2)) + assert.Equal(t, "xml parse tags should include sub tags", err.Error()) +} + +type Section struct { + Name string `xml:"name"` } diff --git a/config/yaml/yaml.go b/core/config/yaml/yaml.go similarity index 57% rename from config/yaml/yaml.go rename to core/config/yaml/yaml.go index 5def2da34e..ea54a1bb8d 100644 --- a/config/yaml/yaml.go +++ b/core/config/yaml/yaml.go @@ -13,42 +13,34 @@ // limitations under the License. // Package yaml for config provider -// -// depend on github.com/beego/goyaml2 -// -// go install github.com/beego/goyaml2 -// // Usage: -// import( -// _ "github.com/astaxie/beego/config/yaml" -// "github.com/astaxie/beego/config" -// ) // -// cnf, err := config.NewConfig("yaml", "config.yaml") +// import( +// _ "github.com/beego/beego/v2/core/config/yaml" +// "github.com/beego/beego/v2/core/config" +// ) // -//More docs http://beego.me/docs/module/config.md +// cnf, err := config.NewConfig("yaml", "config.yaml") package yaml import ( - "bytes" - "encoding/json" "errors" "fmt" - "io/ioutil" - "log" "os" "strings" "sync" - "github.com/astaxie/beego/config" - "github.com/beego/goyaml2" + "gopkg.in/yaml.v3" + + "github.com/beego/beego/v2/core/config" + "github.com/beego/beego/v2/core/logs" ) // Config is a yaml config parser and implements Config interface. type Config struct{} // Parse returns a ConfigContainer with parsed yaml config map. -func (yaml *Config) Parse(filename string) (y config.Configer, err error) { +func (*Config) Parse(filename string) (y config.Configer, err error) { cnf, err := ReadYmlReader(filename) if err != nil { return @@ -60,7 +52,7 @@ func (yaml *Config) Parse(filename string) (y config.Configer, err error) { } // ParseData parse yaml data -func (yaml *Config) ParseData(data []byte) (config.Configer, error) { +func (*Config) ParseData(data []byte) (config.Configer, error) { cnf, err := parseYML(data) if err != nil { return nil, err @@ -72,9 +64,8 @@ func (yaml *Config) ParseData(data []byte) (config.Configer, error) { } // ReadYmlReader Read yaml file to map. -// if json like, use json package, unless goyaml2 package. func ReadYmlReader(path string) (cnf map[string]interface{}, err error) { - buf, err := ioutil.ReadFile(path) + buf, err := os.ReadFile(path) if err != nil { return } @@ -83,43 +74,69 @@ func ReadYmlReader(path string) (cnf map[string]interface{}, err error) { } // parseYML parse yaml formatted []byte to map. -func parseYML(buf []byte) (cnf map[string]interface{}, err error) { - if len(buf) < 3 { - return +func parseYML(buf []byte) (map[string]interface{}, error) { + cnf := make(map[string]interface{}) + err := yaml.Unmarshal(buf, cnf) + if err != nil { + return nil, err } + cnf = config.ExpandValueEnvForMap(cnf) + return cnf, err +} - if string(buf[0:1]) == "{" { - log.Println("Look like a Json, try json umarshal") - err = json.Unmarshal(buf, &cnf) - if err == nil { - log.Println("It is Json Map") - return - } +// ConfigContainer is a config which represents the yaml configuration. +type ConfigContainer struct { + data map[string]interface{} + sync.RWMutex +} + +// Unmarshaler is similar to Sub +func (c *ConfigContainer) Unmarshaler(prefix string, obj interface{}, _ ...config.DecodeOption) error { + sub, err := c.subMap(prefix) + if err != nil { + return err } - data, err := goyaml2.Read(bytes.NewReader(buf)) + bytes, err := yaml.Marshal(sub) if err != nil { - log.Println("Goyaml2 ERR>", string(buf), err) - return + return err } + return yaml.Unmarshal(bytes, obj) +} - if data == nil { - log.Println("Goyaml2 output nil? Pls report bug\n" + string(buf)) - return +func (c *ConfigContainer) Sub(key string) (config.Configer, error) { + sub, err := c.subMap(key) + if err != nil { + return nil, err } - cnf, ok := data.(map[string]interface{}) - if !ok { - log.Println("Not a Map? >> ", string(buf), data) - cnf = nil + return &ConfigContainer{ + data: sub, + }, nil +} + +func (c *ConfigContainer) subMap(key string) (map[string]interface{}, error) { + tmpData := c.data + keys := strings.Split(key, ".") + for idx, k := range keys { + if v, ok := tmpData[k]; ok { + switch val := v.(type) { + case map[string]interface{}: + tmpData = val + if idx == len(keys)-1 { + return tmpData, nil + } + default: + return nil, fmt.Errorf("the key is invalid: %s", key) + } + } } - cnf = config.ExpandValueEnvForMap(cnf) - return + + return tmpData, nil } -// ConfigContainer A Config represents the yaml configuration. -type ConfigContainer struct { - data map[string]interface{} - sync.RWMutex +func (*ConfigContainer) OnChange(_ string, _ func(value string)) { + // do nothing + logs.Warn("Unsupported operation: OnChange") } // Bool returns the boolean value for a given key. @@ -132,11 +149,11 @@ func (c *ConfigContainer) Bool(key string) (bool, error) { } // DefaultBool return the bool value if has no error -// otherwise return the defaultval -func (c *ConfigContainer) DefaultBool(key string, defaultval bool) bool { +// otherwise return the defaultVal +func (c *ConfigContainer) DefaultBool(key string, defaultVal bool) bool { v, err := c.Bool(key) if err != nil { - return defaultval + return defaultVal } return v } @@ -154,31 +171,37 @@ func (c *ConfigContainer) Int(key string) (int, error) { } // DefaultInt returns the integer value for a given key. -// if err != nil return defaultval -func (c *ConfigContainer) DefaultInt(key string, defaultval int) int { +// if err != nil return defaultVal +func (c *ConfigContainer) DefaultInt(key string, defaultVal int) int { v, err := c.Int(key) if err != nil { - return defaultval + return defaultVal } return v } // Int64 returns the int64 value for a given key. func (c *ConfigContainer) Int64(key string) (int64, error) { - if v, err := c.getData(key); err != nil { + v, err := c.getData(key) + if err != nil { return 0, err - } else if vv, ok := v.(int64); ok { - return vv, nil } - return 0, errors.New("not bool value") + switch val := v.(type) { + case int: + return int64(val), nil + case int64: + return val, nil + default: + return 0, errors.New("not int or int64 value") + } } // DefaultInt64 returns the int64 value for a given key. -// if err != nil return defaultval -func (c *ConfigContainer) DefaultInt64(key string, defaultval int64) int64 { +// if err != nil return defaultVal +func (c *ConfigContainer) DefaultInt64(key string, defaultVal int64) int64 { v, err := c.Int64(key) if err != nil { - return defaultval + return defaultVal } return v } @@ -198,59 +221,69 @@ func (c *ConfigContainer) Float(key string) (float64, error) { } // DefaultFloat returns the float64 value for a given key. -// if err != nil return defaultval -func (c *ConfigContainer) DefaultFloat(key string, defaultval float64) float64 { +// if err != nil return defaultVal +func (c *ConfigContainer) DefaultFloat(key string, defaultVal float64) float64 { v, err := c.Float(key) if err != nil { - return defaultval + return defaultVal } return v } // String returns the string value for a given key. -func (c *ConfigContainer) String(key string) string { +func (c *ConfigContainer) String(key string) (string, error) { if v, err := c.getData(key); err == nil { if vv, ok := v.(string); ok { - return vv + return vv, nil } } - return "" + return "", nil } // DefaultString returns the string value for a given key. -// if err != nil return defaultval -func (c *ConfigContainer) DefaultString(key string, defaultval string) string { - v := c.String(key) - if v == "" { - return defaultval +// if err != nil return defaultVal +func (c *ConfigContainer) DefaultString(key string, defaultVal string) string { + v, err := c.String(key) + if v == "" || err != nil { + return defaultVal } return v } // Strings returns the []string value for a given key. -func (c *ConfigContainer) Strings(key string) []string { - v := c.String(key) - if v == "" { - return nil +func (c *ConfigContainer) Strings(key string) ([]string, error) { + v, err := c.String(key) + if v == "" || err != nil { + return nil, err } - return strings.Split(v, ";") + return strings.Split(v, ";"), nil } // DefaultStrings returns the []string value for a given key. -// if err != nil return defaultval -func (c *ConfigContainer) DefaultStrings(key string, defaultval []string) []string { - v := c.Strings(key) - if v == nil { - return defaultval +// if err != nil return defaultVal +func (c *ConfigContainer) DefaultStrings(key string, defaultVal []string) []string { + v, err := c.Strings(key) + if v == nil || err != nil { + return defaultVal } return v } // GetSection returns map for the given section func (c *ConfigContainer) GetSection(section string) (map[string]string, error) { - if v, ok := c.data[section]; ok { - return v.(map[string]string), nil + switch val := v.(type) { + case map[string]interface{}: + res := make(map[string]string, len(val)) + for k2, v2 := range val { + res[k2] = fmt.Sprintf("%v", v2) + } + return res, nil + case map[string]string: + return val, nil + default: + return nil, fmt.Errorf("unexpected type: %v", v) + } } return nil, errors.New("not exist section") } @@ -263,7 +296,11 @@ func (c *ConfigContainer) SaveConfigFile(filename string) (err error) { return err } defer f.Close() - err = goyaml2.Write(f, c.data) + buf, err := yaml.Marshal(c.data) + if err != nil { + return err + } + _, err = f.Write(buf) return err } @@ -281,8 +318,7 @@ func (c *ConfigContainer) DIY(key string) (v interface{}, err error) { } func (c *ConfigContainer) getData(key string) (interface{}, error) { - - if len(key) == 0 { + if key == "" { return nil, errors.New("key is empty") } c.RLock() @@ -292,19 +328,14 @@ func (c *ConfigContainer) getData(key string) (interface{}, error) { tmpData := c.data for idx, k := range keys { if v, ok := tmpData[k]; ok { - switch v.(type) { + switch val := v.(type) { case map[string]interface{}: - { - tmpData = v.(map[string]interface{}) - if idx == len(keys) - 1 { - return tmpData, nil - } + tmpData = val + if idx == len(keys)-1 { + return tmpData, nil } default: - { - return v, nil - } - + return v, nil } } } diff --git a/core/config/yaml/yaml_test.go b/core/config/yaml/yaml_test.go new file mode 100644 index 0000000000..0430962e41 --- /dev/null +++ b/core/config/yaml/yaml_test.go @@ -0,0 +1,202 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package yaml + +import ( + "fmt" + "os" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/beego/beego/v2/core/config" +) + +func TestYaml(t *testing.T) { + var ( + yamlcontext = ` +"appname": beeapi +"httpport": 8080 +"mysqlport": 3600 +"PI": 3.1415976 +"runmode": dev +"autorender": false +"copyrequestbody": true +"PATH": GOPATH +"path1": ${GOPATH} +"path2": ${GOPATH||/home/go} +"empty": "" +"user": + "name": "tom" + "age": 13 +` + + keyValue = map[string]interface{}{ + "appname": "beeapi", + "httpport": 8080, + "mysqlport": int64(3600), + "PI": 3.1415976, + "runmode": "dev", + "autorender": false, + "copyrequestbody": true, + "PATH": "GOPATH", + "path1": os.Getenv("GOPATH"), + "path2": os.Getenv("GOPATH"), + "error": "", + "emptystrings": []string{}, + } + ) + + f, err := os.Create("testyaml.conf") + if err != nil { + t.Fatal(err) + } + _, err = f.WriteString(yamlcontext) + if err != nil { + f.Close() + t.Fatal(err) + } + f.Close() + defer os.Remove("testyaml.conf") + + yamlconf, err := config.NewConfig("yaml", "testyaml.conf") + if err != nil { + t.Fatal(err) + } + + m, err := ReadYmlReader("testyaml.conf") + if err != nil { + t.Fatal(err) + } + assert.Equal(t, m, yamlconf.(*ConfigContainer).data) + + shadow, err := (&Config{}).ParseData([]byte(yamlcontext)) + if err != nil { + t.Fatal(err) + } + assert.Equal(t, shadow, yamlconf) + + yamlconf.OnChange("abc", func(value string) { + fmt.Printf("on change, value is %s \n", value) + }) + + res, _ := yamlconf.String("appname") + if res != "beeapi" { + t.Fatal("appname not equal to beeapi") + } + + for k, v := range keyValue { + + var ( + value interface{} + err error + ) + + switch v.(type) { + case int: + value, err = yamlconf.Int(k) + case int64: + value, err = yamlconf.Int64(k) + case float64: + value, err = yamlconf.Float(k) + case bool: + value, err = yamlconf.Bool(k) + case []string: + value, err = yamlconf.Strings(k) + case string: + value, err = yamlconf.String(k) + default: + value, err = yamlconf.DIY(k) + } + if err != nil { + t.Errorf("get key %q value fatal, %v err %s", k, v, err) + } else if fmt.Sprintf("%v", v) != fmt.Sprintf("%v", value) { + t.Errorf("get key %q value, want %v got %v .", k, v, value) + } + + } + + if err = yamlconf.Set("name", "astaxie"); err != nil { + t.Fatal(err) + } + res, _ = yamlconf.String("name") + if res != "astaxie" { + t.Fatal("get name error") + } + + sub, err := yamlconf.Sub("user") + if err != nil { + t.Fatal(err) + } + assert.NotNil(t, sub) + name, err := sub.String("name") + assert.Nil(t, err) + assert.Equal(t, "tom", name) + + age, err := sub.Int("age") + assert.Nil(t, err) + assert.Equal(t, 13, age) + + user := &User{} + + err = sub.Unmarshaler("", user) + assert.Nil(t, err) + assert.Equal(t, "tom", user.Name) + assert.Equal(t, 13, user.Age) + + user = &User{} + + err = yamlconf.Unmarshaler("user", user) + assert.Nil(t, err) + assert.Equal(t, "tom", user.Name) + assert.Equal(t, 13, user.Age) + + // default value + assert.Equal(t, "beeapi", yamlconf.DefaultString("appname", "invalid")) + assert.Equal(t, "invalid", yamlconf.DefaultString("i-appname", "invalid")) + assert.Equal(t, 8080, yamlconf.DefaultInt("httpport", 8090)) + assert.Equal(t, 8090, yamlconf.DefaultInt("i-httpport", 8090)) + assert.Equal(t, 3.1415976, yamlconf.DefaultFloat("PI", 3.14)) + assert.Equal(t, 3.14, yamlconf.DefaultFloat("1-PI", 3.14)) + assert.True(t, yamlconf.DefaultBool("copyrequestbody", false)) + assert.True(t, yamlconf.DefaultBool("i-copyrequestbody", true)) + assert.Equal(t, int64(8080), yamlconf.DefaultInt64("httpport", 8090)) + assert.Equal(t, int64(8090), yamlconf.DefaultInt64("i-httpport", 8090)) + assert.Equal(t, "tom", yamlconf.DefaultString("user.name", "invalid")) + assert.Equal(t, "invalid", yamlconf.DefaultString("user.1-name", "invalid")) + assert.Equal(t, []string{"tom"}, yamlconf.DefaultStrings("strings", []string{"tom"})) + + appName, err := yamlconf.DIY("appname") + if err != nil { + t.Fatal(err) + } + assert.Equal(t, "beeapi", appName) + + err = yamlconf.SaveConfigFile(f.Name()) + if err != nil { + t.Fatal(err) + } + + section, err := yamlconf.GetSection("user") + if err != nil { + t.Fatal(err) + } + assert.Equal(t, "tom", section["name"]) +} + +type User struct { + Name string `yaml:"name"` + Age int `yaml:"age"` +} diff --git a/logs/README.md b/core/logs/README.md similarity index 68% rename from logs/README.md rename to core/logs/README.md index c05bcc0444..5d6ad665fd 100644 --- a/logs/README.md +++ b/core/logs/README.md @@ -1,24 +1,22 @@ ## logs -logs is a Go logs manager. It can use many logs adapters. The repo is inspired by `database/sql` . +logs is a Go logs manager. It can use many logs adapters. The repo is inspired by `database/sql` . ## How to install? - go get github.com/astaxie/beego/logs - + go get github.com/beego/beego/v2/core/logs ## What adapters are supported? As of now this logs support console, file,smtp and conn. - ## How to use it? First you must import it ```golang import ( - "github.com/astaxie/beego/logs" + "github.com/beego/beego/v2/core/logs" ) ``` @@ -70,3 +68,24 @@ log.SetLogger("smtp", `{"username":"beegotest@gmail.com","password":"xxxxxxxx"," log.Critical("sendmail critical") time.Sleep(time.Second * 30) ``` + +## AccessLog in beego + +Current support three format of accesslog: +- apache format +- json format +- custom formt + +### how to using custom format in beego + +``` +// define the custom format function +log.CustomAccessLogFunc = func(r *beelog.AccessLogRecord) string { + return fmt.Sprintf("custom access log %s", r.Request) +} + +// open access log and set format in beego config. +Config.Log.AccessLogs = true +Config.Log.AccessLogsFormat = "JSON_FORMAT" + +``` \ No newline at end of file diff --git a/logs/accesslog.go b/core/logs/access_log.go similarity index 82% rename from logs/accesslog.go rename to core/logs/access_log.go index 3ff9e20fc8..2f3395a7d9 100644 --- a/logs/accesslog.go +++ b/core/logs/access_log.go @@ -16,9 +16,9 @@ package logs import ( "bytes" - "strings" "encoding/json" "fmt" + "strings" "time" ) @@ -26,9 +26,12 @@ const ( apacheFormatPattern = "%s - - [%s] \"%s %d %d\" %f %s %s" apacheFormat = "APACHE_FORMAT" jsonFormat = "JSON_FORMAT" + customFormat = "CUSTOM_FORMAT" ) -// AccessLogRecord struct for holding access log data. +var CustomAccessLogFunc func(r *AccessLogRecord) string + +// AccessLogRecord is astruct for holding access log data. type AccessLogRecord struct { RemoteAddr string `json:"remote_addr"` RequestTime time.Time `json:"request_time"` @@ -63,12 +66,27 @@ func disableEscapeHTML(i interface{}) { // AccessLog - Format and print access log. func AccessLog(r *AccessLogRecord, format string) { - var msg string + msg := r.format(format) + lm := &LogMsg{ + Msg: strings.TrimSpace(msg), + When: time.Now(), + Level: levelLoggerImpl, + } + beeLogger.writeMsg(lm) +} + +func (r *AccessLogRecord) format(format string) string { + msg := "" switch format { case apacheFormat: timeFormatted := r.RequestTime.Format("02/Jan/2006 03:04:05") msg = fmt.Sprintf(apacheFormatPattern, r.RemoteAddr, timeFormatted, r.Request, r.Status, r.BodyBytesSent, r.ElapsedTime.Seconds(), r.HTTPReferrer, r.HTTPUserAgent) + case customFormat: + if CustomAccessLogFunc != nil { + return CustomAccessLogFunc(r) + } + fallthrough case jsonFormat: fallthrough default: @@ -79,5 +97,5 @@ func AccessLog(r *AccessLogRecord, format string) { msg = string(jsonData) } } - beeLogger.writeMsg(levelLoggerImpl, strings.TrimSpace(msg)) + return msg } diff --git a/core/logs/access_log_test.go b/core/logs/access_log_test.go new file mode 100644 index 0000000000..f78a00a03d --- /dev/null +++ b/core/logs/access_log_test.go @@ -0,0 +1,38 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package logs + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestAccessLog_format(t *testing.T) { + alc := &AccessLogRecord{ + RequestTime: time.Date(2020, 9, 19, 21, 21, 21, 11, time.UTC), + } + + res := alc.format(apacheFormat) + println(res) + assert.Equal(t, " - - [19/Sep/2020 09:21:21] \" 0 0\" 0.000000 ", res) + + res = alc.format(jsonFormat) + assert.Equal(t, + "{\"remote_addr\":\"\",\"request_time\":\"2020-09-19T21:21:21.000000011Z\",\"request_method\":\"\",\"request\":\"\",\"server_protocol\":\"\",\"host\":\"\",\"status\":0,\"body_bytes_sent\":0,\"elapsed_time\":0,\"http_referrer\":\"\",\"http_user_agent\":\"\",\"remote_user\":\"\"}\n", res) + + AccessLog(alc, jsonFormat) +} diff --git a/logs/alils/alils.go b/core/logs/alils/alils.go similarity index 70% rename from logs/alils/alils.go rename to core/logs/alils/alils.go index 867ff4cb53..4e0d3089ff 100644 --- a/logs/alils/alils.go +++ b/core/logs/alils/alils.go @@ -2,18 +2,19 @@ package alils import ( "encoding/json" + "fmt" "strings" "sync" - "time" - "github.com/astaxie/beego/logs" "github.com/gogo/protobuf/proto" + + "github.com/beego/beego/v2/core/logs" ) const ( - // CacheSize set the flush size + // CacheSize sets the flush size CacheSize int = 64 - // Delimiter define the topic delimiter + // Delimiter defines the topic delimiter Delimiter string = "##" ) @@ -28,10 +29,11 @@ type Config struct { Source string `json:"source"` Level int `json:"level"` FlushWhen int `json:"flush_when"` + Formatter string `json:"formatter"` } // aliLSWriter implements LoggerInterface. -// it writes messages in keep-live tcp connection. +// Writes messages in keep-live tcp connection. type aliLSWriter struct { store *LogStore group []*LogGroup @@ -39,19 +41,23 @@ type aliLSWriter struct { groupMap map[string]*LogGroup lock *sync.Mutex Config + formatter logs.LogFormatter } -// NewAliLS create a new Logger +// NewAliLS creates a new Logger func NewAliLS() logs.Logger { alils := new(aliLSWriter) alils.Level = logs.LevelTrace + alils.formatter = alils return alils } -// Init parse config and init struct -func (c *aliLSWriter) Init(jsonConfig string) (err error) { - - json.Unmarshal([]byte(jsonConfig), c) +// Init parses config and initializes struct +func (c *aliLSWriter) Init(config string) error { + err := json.Unmarshal([]byte(config), c) + if err != nil { + return err + } if c.FlushWhen > CacheSize { c.FlushWhen = CacheSize @@ -64,11 +70,13 @@ func (c *aliLSWriter) Init(jsonConfig string) (err error) { AccessKeySecret: c.KeySecret, } - c.store, err = prj.GetLogStore(c.LogStore) + store, err := prj.GetLogStore(c.LogStore) if err != nil { return err } + c.store = store + // Create default Log Group c.group = append(c.group, &LogGroup{ Topic: proto.String(""), @@ -98,14 +106,29 @@ func (c *aliLSWriter) Init(jsonConfig string) (err error) { c.lock = &sync.Mutex{} + if len(c.Formatter) > 0 { + fmtr, ok := logs.GetFormatter(c.Formatter) + if !ok { + return fmt.Errorf("the formatter with name: %s not found", c.Formatter) + } + c.formatter = fmtr + } + return nil } -// WriteMsg write message in connection. -// if connection is down, try to re-connect. -func (c *aliLSWriter) WriteMsg(when time.Time, msg string, level int) (err error) { +func (c *aliLSWriter) Format(lm *logs.LogMsg) string { + return lm.OldStyleFormat() +} + +func (c *aliLSWriter) SetFormatter(f logs.LogFormatter) { + c.formatter = f +} - if level > c.Level { +// WriteMsg writes a message in connection. +// If connection is down, try to re-connect. +func (c *aliLSWriter) WriteMsg(lm *logs.LogMsg) error { + if lm.Level > c.Level { return nil } @@ -115,31 +138,30 @@ func (c *aliLSWriter) WriteMsg(when time.Time, msg string, level int) (err error if c.withMap { // Topic,LogGroup - strs := strings.SplitN(msg, Delimiter, 2) + strs := strings.SplitN(lm.Msg, Delimiter, 2) if len(strs) == 2 { pos := strings.LastIndex(strs[0], " ") topic = strs[0][pos+1 : len(strs[0])] - content = strs[0][0:pos] + strs[1] lg = c.groupMap[topic] } // send to empty Topic if lg == nil { - content = msg lg = c.group[0] } } else { - content = msg lg = c.group[0] } + content = c.formatter.Format(lm) + c1 := &LogContent{ Key: proto.String("msg"), Value: proto.String(content), } l := &Log{ - Time: proto.Uint32(uint32(when.Unix())), + Time: proto.Uint32(uint32(lm.When.Unix())), Contents: []*LogContent{ c1, }, @@ -152,13 +174,11 @@ func (c *aliLSWriter) WriteMsg(when time.Time, msg string, level int) (err error if len(lg.Logs) >= c.FlushWhen { c.flush(lg) } - return nil } // Flush implementing method. empty. func (c *aliLSWriter) Flush() { - // flush all group for _, lg := range c.group { c.flush(lg) @@ -170,7 +190,6 @@ func (c *aliLSWriter) Destroy() { } func (c *aliLSWriter) flush(lg *LogGroup) { - c.lock.Lock() defer c.lock.Unlock() err := c.store.PutLogs(lg) diff --git a/logs/alils/config.go b/core/logs/alils/config.go similarity index 61% rename from logs/alils/config.go rename to core/logs/alils/config.go index e8c24448fc..558db6a820 100755 --- a/logs/alils/config.go +++ b/core/logs/alils/config.go @@ -4,10 +4,10 @@ const ( version = "0.5.0" // SDK version signatureMethod = "hmac-sha1" // Signature method - // OffsetNewest stands for the log head offset, i.e. the offset that will be + // OffsetNewest is the log head offset, i.e. the offset that will be // assigned to the next message that will be produced to the shard. OffsetNewest = "end" - // OffsetOldest stands for the oldest offset available on the logstore for a + // OffsetOldest is the oldest offset available on the logstore for a // shard. OffsetOldest = "begin" ) diff --git a/logs/alils/log.pb.go b/core/logs/alils/log.pb.go similarity index 93% rename from logs/alils/log.pb.go rename to core/logs/alils/log.pb.go index 601b0d78d3..f275c21805 100755 --- a/logs/alils/log.pb.go +++ b/core/logs/alils/log.pb.go @@ -5,13 +5,14 @@ import ( "io" "math" - "github.com/gogo/protobuf/proto" github_com_gogo_protobuf_proto "github.com/gogo/protobuf/proto" ) // Reference imports to suppress errors if they are not otherwise used. -var _ = proto.Marshal +var _ = github_com_gogo_protobuf_proto.Marshal + var _ = fmt.Errorf + var _ = math.Inf var ( @@ -31,13 +32,13 @@ type Log struct { // Reset the Log func (m *Log) Reset() { *m = Log{} } -// String return the Compact Log -func (m *Log) String() string { return proto.CompactTextString(m) } +// String returns the Compact Log +func (m *Log) String() string { return github_com_gogo_protobuf_proto.CompactTextString(m) } // ProtoMessage not implemented func (*Log) ProtoMessage() {} -// GetTime return the Log's Time +// GetTime returns the Log's Time func (m *Log) GetTime() uint32 { if m != nil && m.Time != nil { return *m.Time @@ -45,7 +46,7 @@ func (m *Log) GetTime() uint32 { return 0 } -// GetContents return the Log's Contents +// GetContents returns the Log's Contents func (m *Log) GetContents() []*LogContent { if m != nil { return m.Contents @@ -53,7 +54,7 @@ func (m *Log) GetContents() []*LogContent { return nil } -// LogContent define the Log content struct +// LogContent defines the Log content struct type LogContent struct { Key *string `protobuf:"bytes,1,req,name=Key" json:"Key,omitempty"` Value *string `protobuf:"bytes,2,req,name=Value" json:"Value,omitempty"` @@ -63,13 +64,13 @@ type LogContent struct { // Reset LogContent func (m *LogContent) Reset() { *m = LogContent{} } -// String return the compact text -func (m *LogContent) String() string { return proto.CompactTextString(m) } +// String returns the compact text +func (m *LogContent) String() string { return github_com_gogo_protobuf_proto.CompactTextString(m) } // ProtoMessage not implemented func (*LogContent) ProtoMessage() {} -// GetKey return the Key +// GetKey returns the key func (m *LogContent) GetKey() string { if m != nil && m.Key != nil { return *m.Key @@ -77,7 +78,7 @@ func (m *LogContent) GetKey() string { return "" } -// GetValue return the Value +// GetValue returns the value func (m *LogContent) GetValue() string { if m != nil && m.Value != nil { return *m.Value @@ -85,7 +86,7 @@ func (m *LogContent) GetValue() string { return "" } -// LogGroup define the logs struct +// LogGroup defines the logs struct type LogGroup struct { Logs []*Log `protobuf:"bytes,1,rep,name=Logs" json:"Logs,omitempty"` Reserved *string `protobuf:"bytes,2,opt,name=Reserved" json:"Reserved,omitempty"` @@ -97,13 +98,13 @@ type LogGroup struct { // Reset LogGroup func (m *LogGroup) Reset() { *m = LogGroup{} } -// String return the compact text -func (m *LogGroup) String() string { return proto.CompactTextString(m) } +// String returns the compact text +func (m *LogGroup) String() string { return github_com_gogo_protobuf_proto.CompactTextString(m) } // ProtoMessage not implemented func (*LogGroup) ProtoMessage() {} -// GetLogs return the loggroup logs +// GetLogs returns the loggroup logs func (m *LogGroup) GetLogs() []*Log { if m != nil { return m.Logs @@ -111,7 +112,8 @@ func (m *LogGroup) GetLogs() []*Log { return nil } -// GetReserved return Reserved +// GetReserved returns Reserved. An empty string is returned +// if an error occurs func (m *LogGroup) GetReserved() string { if m != nil && m.Reserved != nil { return *m.Reserved @@ -119,7 +121,8 @@ func (m *LogGroup) GetReserved() string { return "" } -// GetTopic return Topic +// GetTopic returns Topic. An empty string is returned +// if an error occurs func (m *LogGroup) GetTopic() string { if m != nil && m.Topic != nil { return *m.Topic @@ -127,7 +130,8 @@ func (m *LogGroup) GetTopic() string { return "" } -// GetSource return Source +// GetSource returns source. An empty string is returned +// if an error occurs func (m *LogGroup) GetSource() string { if m != nil && m.Source != nil { return *m.Source @@ -135,7 +139,7 @@ func (m *LogGroup) GetSource() string { return "" } -// LogGroupList define the LogGroups +// LogGroupList defines the LogGroups type LogGroupList struct { LogGroups []*LogGroup `protobuf:"bytes,1,rep,name=logGroups" json:"logGroups,omitempty"` XXXUnrecognized []byte `json:"-"` @@ -144,13 +148,13 @@ type LogGroupList struct { // Reset LogGroupList func (m *LogGroupList) Reset() { *m = LogGroupList{} } -// String return compact text -func (m *LogGroupList) String() string { return proto.CompactTextString(m) } +// String returns compact text +func (m *LogGroupList) String() string { return github_com_gogo_protobuf_proto.CompactTextString(m) } // ProtoMessage not implemented func (*LogGroupList) ProtoMessage() {} -// GetLogGroups return the LogGroups +// GetLogGroups returns the LogGroups func (m *LogGroupList) GetLogGroups() []*LogGroup { if m != nil { return m.LogGroups @@ -158,7 +162,7 @@ func (m *LogGroupList) GetLogGroups() []*LogGroup { return nil } -// Marshal the logs to byte slice +// Marshal marshals the logs to byte slice func (m *Log) Marshal() (data []byte, err error) { size := m.Size() data = make([]byte, size) @@ -336,6 +340,7 @@ func encodeFixed64Log(data []byte, offset int, v uint64) int { data[offset+7] = uint8(v >> 56) return offset + 8 } + func encodeFixed32Log(data []byte, offset int, v uint32) int { data[offset] = uint8(v) data[offset+1] = uint8(v >> 8) @@ -343,6 +348,7 @@ func encodeFixed32Log(data []byte, offset int, v uint32) int { data[offset+3] = uint8(v >> 24) return offset + 4 } + func encodeVarintLog(data []byte, offset int, v uint64) int { for v >= 1<<7 { data[offset] = uint8(v&0x7f | 0x80) @@ -353,7 +359,7 @@ func encodeVarintLog(data []byte, offset int, v uint64) int { return offset + 1 } -// Size return the log's size +// Size returns the log's size func (m *Log) Size() (n int) { var l int _ = l @@ -372,7 +378,7 @@ func (m *Log) Size() (n int) { return n } -// Size return LogContent size based on Key and Value +// Size returns LogContent size based on Key and Value func (m *LogContent) Size() (n int) { var l int _ = l @@ -390,7 +396,7 @@ func (m *LogContent) Size() (n int) { return n } -// Size return LogGroup size based on Logs +// Size returns LogGroup size based on Logs func (m *LogGroup) Size() (n int) { var l int _ = l @@ -418,7 +424,7 @@ func (m *LogGroup) Size() (n int) { return n } -// Size return LogGroupList size +// Size returns LogGroupList size func (m *LogGroupList) Size() (n int) { var l int _ = l @@ -444,11 +450,12 @@ func sovLog(x uint64) (n int) { } return n } + func sozLog(x uint64) (n int) { return sovLog((x << 1) ^ (x >> 63)) } -// Unmarshal data to log +// Unmarshal unmarshals data to log func (m *Log) Unmarshal(data []byte) error { var hasFields [1]uint64 l := len(data) @@ -557,7 +564,7 @@ func (m *Log) Unmarshal(data []byte) error { return nil } -// Unmarshal data to LogContent +// Unmarshal unmarshals data to LogContent func (m *LogContent) Unmarshal(data []byte) error { var hasFields [1]uint64 l := len(data) @@ -679,7 +686,7 @@ func (m *LogContent) Unmarshal(data []byte) error { return nil } -// Unmarshal data to LogGroup +// Unmarshal unmarshals data to LogGroup func (m *LogGroup) Unmarshal(data []byte) error { l := len(data) iNdEx := 0 @@ -853,7 +860,7 @@ func (m *LogGroup) Unmarshal(data []byte) error { return nil } -// Unmarshal data to LogGroupList +// Unmarshal unmarshals data to LogGroupList func (m *LogGroupList) Unmarshal(data []byte) error { l := len(data) iNdEx := 0 diff --git a/logs/alils/log_config.go b/core/logs/alils/log_config.go similarity index 91% rename from logs/alils/log_config.go rename to core/logs/alils/log_config.go index e8564efbd0..7daeb8648b 100755 --- a/logs/alils/log_config.go +++ b/core/logs/alils/log_config.go @@ -1,6 +1,6 @@ package alils -// InputDetail define log detail +// InputDetail defines log detail type InputDetail struct { LogType string `json:"logType"` LogPath string `json:"logPath"` @@ -15,13 +15,13 @@ type InputDetail struct { TopicFormat string `json:"topicFormat"` } -// OutputDetail define the output detail +// OutputDetail defines the output detail type OutputDetail struct { Endpoint string `json:"endpoint"` LogStoreName string `json:"logstoreName"` } -// LogConfig define Log Config +// LogConfig defines Log Config type LogConfig struct { Name string `json:"configName"` InputType string `json:"inputType"` diff --git a/logs/alils/log_project.go b/core/logs/alils/log_project.go similarity index 95% rename from logs/alils/log_project.go rename to core/logs/alils/log_project.go index 59db8cbf78..700b50c7bb 100755 --- a/logs/alils/log_project.go +++ b/core/logs/alils/log_project.go @@ -9,7 +9,7 @@ package alils import ( "encoding/json" "fmt" - "io/ioutil" + "io" "net/http" "net/http/httputil" ) @@ -20,7 +20,7 @@ type errorMessage struct { Message string `json:"errorMessage"` } -// LogProject Define the Ali Project detail +// LogProject defines the Ali Project detail type LogProject struct { Name string // Project name Endpoint string // IP or hostname of SLS endpoint @@ -45,13 +45,13 @@ func (p *LogProject) ListLogStore() (storeNames []string, err error) { "x-sls-bodyrawsize": "0", } - uri := fmt.Sprintf("/logstores") + uri := "/logstores" r, err := request(p, "GET", uri, h, nil) if err != nil { return } - buf, err := ioutil.ReadAll(r.Body) + buf, err := io.ReadAll(r.Body) if err != nil { return } @@ -96,7 +96,7 @@ func (p *LogProject) GetLogStore(name string) (s *LogStore, err error) { return } - buf, err := ioutil.ReadAll(r.Body) + buf, err := io.ReadAll(r.Body) if err != nil { return } @@ -128,7 +128,6 @@ func (p *LogProject) GetLogStore(name string) (s *LogStore, err error) { // and ttl is time-to-live(in day) of logs, // and shardCnt is the number of shards. func (p *LogProject) CreateLogStore(name string, ttl, shardCnt int) (err error) { - type Body struct { Name string `json:"logstoreName"` TTL int `json:"ttl"` @@ -157,7 +156,7 @@ func (p *LogProject) CreateLogStore(name string, ttl, shardCnt int) (err error) return } - body, err = ioutil.ReadAll(r.Body) + body, err = io.ReadAll(r.Body) if err != nil { return } @@ -189,7 +188,7 @@ func (p *LogProject) DeleteLogStore(name string) (err error) { return } - body, err := ioutil.ReadAll(r.Body) + body, err := io.ReadAll(r.Body) if err != nil { return } @@ -212,7 +211,6 @@ func (p *LogProject) DeleteLogStore(name string) (err error) { // UpdateLogStore updates a logstore according by logstore name, // obviously we can't modify the logstore name itself. func (p *LogProject) UpdateLogStore(name string, ttl, shardCnt int) (err error) { - type Body struct { Name string `json:"logstoreName"` TTL int `json:"ttl"` @@ -241,7 +239,7 @@ func (p *LogProject) UpdateLogStore(name string, ttl, shardCnt int) (err error) return } - body, err = ioutil.ReadAll(r.Body) + body, err = io.ReadAll(r.Body) if err != nil { return } @@ -279,7 +277,7 @@ func (p *LogProject) ListMachineGroup(offset, size int) (m []string, total int, return } - buf, err := ioutil.ReadAll(r.Body) + buf, err := io.ReadAll(r.Body) if err != nil { return } @@ -326,7 +324,7 @@ func (p *LogProject) GetMachineGroup(name string) (m *MachineGroup, err error) { return } - buf, err := ioutil.ReadAll(r.Body) + buf, err := io.ReadAll(r.Body) if err != nil { return } @@ -355,7 +353,6 @@ func (p *LogProject) GetMachineGroup(name string) (m *MachineGroup, err error) { // CreateMachineGroup creates a new machine group in SLS. func (p *LogProject) CreateMachineGroup(m *MachineGroup) (err error) { - body, err := json.Marshal(m) if err != nil { return @@ -372,7 +369,7 @@ func (p *LogProject) CreateMachineGroup(m *MachineGroup) (err error) { return } - body, err = ioutil.ReadAll(r.Body) + body, err = io.ReadAll(r.Body) if err != nil { return } @@ -395,7 +392,6 @@ func (p *LogProject) CreateMachineGroup(m *MachineGroup) (err error) { // UpdateMachineGroup updates a machine group. func (p *LogProject) UpdateMachineGroup(m *MachineGroup) (err error) { - body, err := json.Marshal(m) if err != nil { return @@ -412,7 +408,7 @@ func (p *LogProject) UpdateMachineGroup(m *MachineGroup) (err error) { return } - body, err = ioutil.ReadAll(r.Body) + body, err = io.ReadAll(r.Body) if err != nil { return } @@ -444,7 +440,7 @@ func (p *LogProject) DeleteMachineGroup(name string) (err error) { return } - body, err := ioutil.ReadAll(r.Body) + body, err := io.ReadAll(r.Body) if err != nil { return } @@ -481,7 +477,7 @@ func (p *LogProject) ListConfig(offset, size int) (cfgNames []string, total int, return } - buf, err := ioutil.ReadAll(r.Body) + buf, err := io.ReadAll(r.Body) if err != nil { return } @@ -526,7 +522,7 @@ func (p *LogProject) GetConfig(name string) (c *LogConfig, err error) { return } - buf, err := ioutil.ReadAll(r.Body) + buf, err := io.ReadAll(r.Body) if err != nil { return } @@ -555,7 +551,6 @@ func (p *LogProject) GetConfig(name string) (c *LogConfig, err error) { // UpdateConfig updates a config. func (p *LogProject) UpdateConfig(c *LogConfig) (err error) { - body, err := json.Marshal(c) if err != nil { return @@ -572,7 +567,7 @@ func (p *LogProject) UpdateConfig(c *LogConfig) (err error) { return } - body, err = ioutil.ReadAll(r.Body) + body, err = io.ReadAll(r.Body) if err != nil { return } @@ -595,7 +590,6 @@ func (p *LogProject) UpdateConfig(c *LogConfig) (err error) { // CreateConfig creates a new config in SLS. func (p *LogProject) CreateConfig(c *LogConfig) (err error) { - body, err := json.Marshal(c) if err != nil { return @@ -612,7 +606,7 @@ func (p *LogProject) CreateConfig(c *LogConfig) (err error) { return } - body, err = ioutil.ReadAll(r.Body) + body, err = io.ReadAll(r.Body) if err != nil { return } @@ -644,7 +638,7 @@ func (p *LogProject) DeleteConfig(name string) (err error) { return } - body, err := ioutil.ReadAll(r.Body) + body, err := io.ReadAll(r.Body) if err != nil { return } @@ -676,7 +670,7 @@ func (p *LogProject) GetAppliedMachineGroups(confName string) (groupNames []stri return } - buf, err := ioutil.ReadAll(r.Body) + buf, err := io.ReadAll(r.Body) if err != nil { return } @@ -721,7 +715,7 @@ func (p *LogProject) GetAppliedConfigs(groupName string) (confNames []string, er return } - buf, err := ioutil.ReadAll(r.Body) + buf, err := io.ReadAll(r.Body) if err != nil { return } @@ -766,7 +760,7 @@ func (p *LogProject) ApplyConfigToMachineGroup(confName, groupName string) (err return } - buf, err := ioutil.ReadAll(r.Body) + buf, err := io.ReadAll(r.Body) if err != nil { return } @@ -798,7 +792,7 @@ func (p *LogProject) RemoveConfigFromMachineGroup(confName, groupName string) (e return } - buf, err := ioutil.ReadAll(r.Body) + buf, err := io.ReadAll(r.Body) if err != nil { return } diff --git a/logs/alils/log_store.go b/core/logs/alils/log_store.go similarity index 94% rename from logs/alils/log_store.go rename to core/logs/alils/log_store.go index fa50273646..e699b83203 100755 --- a/logs/alils/log_store.go +++ b/core/logs/alils/log_store.go @@ -3,7 +3,7 @@ package alils import ( "encoding/json" "fmt" - "io/ioutil" + "io" "net/http" "net/http/httputil" "strconv" @@ -12,7 +12,7 @@ import ( "github.com/gogo/protobuf/proto" ) -// LogStore Store the logs +// LogStore stores the logs type LogStore struct { Name string `json:"logstoreName"` TTL int @@ -24,7 +24,7 @@ type LogStore struct { project *LogProject } -// Shard define the Log Shard +// Shard defines the Log Shard type Shard struct { ShardID int `json:"shardID"` } @@ -41,7 +41,7 @@ func (s *LogStore) ListShards() (shardIDs []int, err error) { return } - buf, err := ioutil.ReadAll(r.Body) + buf, err := io.ReadAll(r.Body) if err != nil { return } @@ -71,7 +71,7 @@ func (s *LogStore) ListShards() (shardIDs []int, err error) { return } -// PutLogs put logs into logstore. +// PutLogs puts logs into logstore. // The callers should transform user logs into LogGroup. func (s *LogStore) PutLogs(lg *LogGroup) (err error) { body, err := proto.Marshal(lg) @@ -79,7 +79,7 @@ func (s *LogStore) PutLogs(lg *LogGroup) (err error) { return } - // Compresse body with lz4 + // Compress body with lz4 out := make([]byte, lz4.CompressBound(body)) n, err := lz4.Compress(body, out) if err != nil { @@ -98,7 +98,7 @@ func (s *LogStore) PutLogs(lg *LogGroup) (err error) { return } - buf, err := ioutil.ReadAll(r.Body) + buf, err := io.ReadAll(r.Body) if err != nil { return } @@ -134,7 +134,7 @@ func (s *LogStore) GetCursor(shardID int, from string) (cursor string, err error return } - buf, err := ioutil.ReadAll(r.Body) + buf, err := io.ReadAll(r.Body) if err != nil { return } @@ -185,7 +185,7 @@ func (s *LogStore) GetLogsBytes(shardID int, cursor string, return } - buf, err := ioutil.ReadAll(r.Body) + buf, err := io.ReadAll(r.Body) if err != nil { return } @@ -239,9 +239,8 @@ func (s *LogStore) GetLogsBytes(shardID int, cursor string, return } -// LogsBytesDecode decodes logs binary data retruned by GetLogsBytes API +// LogsBytesDecode decodes logs binary data returned by GetLogsBytes API func LogsBytesDecode(data []byte) (gl *LogGroupList, err error) { - gl = &LogGroupList{} err = proto.Unmarshal(data, gl) if err != nil { diff --git a/logs/alils/machine_group.go b/core/logs/alils/machine_group.go similarity index 86% rename from logs/alils/machine_group.go rename to core/logs/alils/machine_group.go index b6c69a141a..b00c480e36 100755 --- a/logs/alils/machine_group.go +++ b/core/logs/alils/machine_group.go @@ -3,18 +3,18 @@ package alils import ( "encoding/json" "fmt" - "io/ioutil" + "io" "net/http" "net/http/httputil" ) -// MachineGroupAttribute define the Attribute +// MachineGroupAttribute defines the Attribute type MachineGroupAttribute struct { ExternalName string `json:"externalName"` TopicName string `json:"groupTopic"` } -// MachineGroup define the machine Group +// MachineGroup defines the machine Group type MachineGroup struct { Name string `json:"groupName"` Type string `json:"groupType"` @@ -29,20 +29,20 @@ type MachineGroup struct { project *LogProject } -// Machine define the Machine +// Machine defines the Machine type Machine struct { IP string UniqueID string `json:"machine-uniqueid"` UserdefinedID string `json:"userdefined-id"` } -// MachineList define the Machine List +// MachineList defines the Machine List type MachineList struct { Total int Machines []*Machine } -// ListMachines returns machine list of this machine group. +// ListMachines returns the machine list of this machine group. func (m *MachineGroup) ListMachines() (ms []*Machine, total int, err error) { h := map[string]string{ "x-sls-bodyrawsize": "0", @@ -54,7 +54,7 @@ func (m *MachineGroup) ListMachines() (ms []*Machine, total int, err error) { return } - buf, err := ioutil.ReadAll(r.Body) + buf, err := io.ReadAll(r.Body) if err != nil { return } diff --git a/logs/alils/request.go b/core/logs/alils/request.go similarity index 92% rename from logs/alils/request.go rename to core/logs/alils/request.go index 50d9c43c56..dce4dccde3 100755 --- a/logs/alils/request.go +++ b/core/logs/alils/request.go @@ -13,7 +13,7 @@ func request(project *LogProject, method, uri string, headers map[string]string, // The caller should provide 'x-sls-bodyrawsize' header if _, ok := headers["x-sls-bodyrawsize"]; !ok { - err = fmt.Errorf("Can't find 'x-sls-bodyrawsize' header") + err = fmt.Errorf("can't find 'x-sls-bodyrawsize' header") return } @@ -27,7 +27,7 @@ func request(project *LogProject, method, uri string, headers map[string]string, headers["Content-MD5"] = bodyMD5 if _, ok := headers["Content-Type"]; !ok { - err = fmt.Errorf("Can't find 'Content-Type' header") + err = fmt.Errorf("can't find 'Content-Type' header") return } } diff --git a/logs/alils/signature.go b/core/logs/alils/signature.go similarity index 100% rename from logs/alils/signature.go rename to core/logs/alils/signature.go diff --git a/logs/conn.go b/core/logs/conn.go similarity index 66% rename from logs/conn.go rename to core/logs/conn.go index afe0cbb75a..131c08ccd9 100644 --- a/logs/conn.go +++ b/core/logs/conn.go @@ -16,16 +16,18 @@ package logs import ( "encoding/json" + "fmt" "io" "net" - "time" ) // connWriter implements LoggerInterface. -// it writes messages in keep-live tcp connection. +// Writes messages in keep-live tcp connection. type connWriter struct { lg *logWriter innerWriter io.WriteCloser + formatter LogFormatter + Formatter string `json:"formatter"` ReconnectOnMsg bool `json:"reconnectOnMsg"` Reconnect bool `json:"reconnect"` Net string `json:"net"` @@ -33,23 +35,40 @@ type connWriter struct { Level int `json:"level"` } -// NewConn create new ConnWrite returning as LoggerInterface. +// NewConn creates new ConnWrite returning as LoggerInterface. func NewConn() Logger { conn := new(connWriter) conn.Level = LevelTrace + conn.formatter = conn return conn } -// Init init connection writer with json config. -// json config only need key "level". -func (c *connWriter) Init(jsonConfig string) error { - return json.Unmarshal([]byte(jsonConfig), c) +func (c *connWriter) Format(lm *LogMsg) string { + return lm.OldStyleFormat() } -// WriteMsg write message in connection. -// if connection is down, try to re-connect. -func (c *connWriter) WriteMsg(when time.Time, msg string, level int) error { - if level > c.Level { +// Init initializes a connection writer with json config. +// json config only needs they "level" key +func (c *connWriter) Init(config string) error { + res := json.Unmarshal([]byte(config), c) + if res == nil && len(c.Formatter) > 0 { + fmtr, ok := GetFormatter(c.Formatter) + if !ok { + return fmt.Errorf("the formatter with name: %s not found", c.Formatter) + } + c.formatter = fmtr + } + return res +} + +func (c *connWriter) SetFormatter(f LogFormatter) { + c.formatter = f +} + +// WriteMsg writes message in connection. +// If connection is down, try to re-connect. +func (c *connWriter) WriteMsg(lm *LogMsg) error { + if lm.Level > c.Level { return nil } if c.needToConnectOnMsg() { @@ -63,13 +82,17 @@ func (c *connWriter) WriteMsg(when time.Time, msg string, level int) error { defer c.innerWriter.Close() } - c.lg.writeln(when, msg) + msg := c.formatter.Format(lm) + + _, err := c.lg.writeln(msg) + if err != nil { + return err + } return nil } // Flush implementing method. empty. func (c *connWriter) Flush() { - } // Destroy destroy connection writer and close tcp listener. @@ -101,7 +124,6 @@ func (c *connWriter) connect() error { func (c *connWriter) needToConnectOnMsg() bool { if c.Reconnect { - c.Reconnect = false return true } diff --git a/core/logs/conn_test.go b/core/logs/conn_test.go new file mode 100644 index 0000000000..bba6bb3895 --- /dev/null +++ b/core/logs/conn_test.go @@ -0,0 +1,96 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package logs + +import ( + "net" + "os" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +// ConnTCPListener takes a TCP listener and accepts n TCP connections +// Returns connections using connChan +func connTCPListener(t *testing.T, n int, ln net.Listener, connChan chan<- net.Conn) { + // Listen and accept n incoming connections + for i := 0; i < n; i++ { + conn, err := ln.Accept() + if err != nil { + t.Log("Error accepting connection: ", err.Error()) + os.Exit(1) + } + + // Send accepted connection to channel + connChan <- conn + } + ln.Close() + close(connChan) +} + +func TestConn(t *testing.T) { + log := NewLogger(1000) + log.SetLogger("conn", `{"net":"tcp","addr":":7020"}`) + log.Informational("informational") +} + +// need to rewrite this test, it's not stable +func TestReconnect(t *testing.T) { + // Setup connection listener + newConns := make(chan net.Conn) + connNum := 3 + ln, err := net.Listen("tcp", ":6002") + if err != nil { + t.Log("Error listening:", err.Error()) + os.Exit(1) + } + go connTCPListener(t, connNum, ln, newConns) + + // Setup logger + log := NewLogger(1000) + log.SetPrefix("test") + log.SetLogger(AdapterConn, `{"net":"tcp","reconnect":true,"level":6,"addr":":6002"}`) + log.Informational("informational 1") + + // Refuse first connection + first := <-newConns + first.Close() + + // Send another log after conn closed + log.Informational("informational 2") + + // Check if there was a second connection attempt + select { + case second := <-newConns: + second.Close() + default: + t.Error("Did not reconnect") + } +} + +func TestConnWriter_Format(t *testing.T) { + lg := &LogMsg{ + Level: LevelDebug, + Msg: "Hello, world", + When: time.Date(2020, 9, 19, 20, 12, 37, 9, time.UTC), + FilePath: "/user/home/main.go", + LineNumber: 13, + Prefix: "Cus", + } + cw := NewConn().(*connWriter) + res := cw.Format(lg) + assert.Equal(t, "[D] Cus Hello, world", res) +} diff --git a/logs/console.go b/core/logs/console.go similarity index 58% rename from logs/console.go rename to core/logs/console.go index 3dcaee1dfe..ccf6d53c82 100644 --- a/logs/console.go +++ b/core/logs/console.go @@ -16,9 +16,9 @@ package logs import ( "encoding/json" + "fmt" "os" "strings" - "time" "github.com/shiena/ansicolor" ) @@ -26,7 +26,7 @@ import ( // brush is a color join function type brush func(string) string -// newBrush return a fix color Brush +// newBrush returns a fix color Brush func newBrush(color string) brush { pre := "\033[" reset := "\033[0m" @@ -48,50 +48,75 @@ var colors = []brush{ // consoleWriter implements LoggerInterface and writes messages to terminal. type consoleWriter struct { - lg *logWriter - Level int `json:"level"` - Colorful bool `json:"color"` //this filed is useful only when system's terminal supports color + lg *logWriter + formatter LogFormatter + Formatter string `json:"formatter"` + Level int `json:"level"` + Colorful bool `json:"color"` // this filed is useful only when system's terminal supports color } -// NewConsole create ConsoleWriter returning as LoggerInterface. +func (c *consoleWriter) Format(lm *LogMsg) string { + msg := lm.OldStyleFormat() + if c.Colorful { + msg = strings.Replace(msg, levelPrefix[lm.Level], colors[lm.Level](levelPrefix[lm.Level]), 1) + } + h, _, _ := formatTimeHeader(lm.When) + return string(append(h, msg...)) +} + +func (c *consoleWriter) SetFormatter(f LogFormatter) { + c.formatter = f +} + +// NewConsole creates ConsoleWriter returning as LoggerInterface. func NewConsole() Logger { + return newConsole() +} + +func newConsole() *consoleWriter { cw := &consoleWriter{ lg: newLogWriter(ansicolor.NewAnsiColorWriter(os.Stdout)), Level: LevelDebug, Colorful: true, } + cw.formatter = cw return cw } -// Init init console logger. -// jsonConfig like '{"level":LevelTrace}'. -func (c *consoleWriter) Init(jsonConfig string) error { - if len(jsonConfig) == 0 { +// Init initianlizes the console logger. +// jsonConfig must be in the format '{"level":LevelTrace}' +func (c *consoleWriter) Init(config string) error { + if len(config) == 0 { return nil } - return json.Unmarshal([]byte(jsonConfig), c) + + res := json.Unmarshal([]byte(config), c) + if res == nil && len(c.Formatter) > 0 { + fmtr, ok := GetFormatter(c.Formatter) + if !ok { + return fmt.Errorf("the formatter with name: %s not found", c.Formatter) + } + c.formatter = fmtr + } + return res } -// WriteMsg write message in console. -func (c *consoleWriter) WriteMsg(when time.Time, msg string, level int) error { - if level > c.Level { +// WriteMsg writes message in console. +func (c *consoleWriter) WriteMsg(lm *LogMsg) error { + if lm.Level > c.Level { return nil } - if c.Colorful { - msg = strings.Replace(msg, levelPrefix[level], colors[level](levelPrefix[level]), 1) - } - c.lg.writeln(when, msg) + msg := c.formatter.Format(lm) + c.lg.writeln(msg) return nil } // Destroy implementing method. empty. func (c *consoleWriter) Destroy() { - } // Flush implementing method. empty. func (c *consoleWriter) Flush() { - } func init() { diff --git a/logs/console_test.go b/core/logs/console_test.go similarity index 66% rename from logs/console_test.go rename to core/logs/console_test.go index 04f2bd7e1e..2f2be0ea04 100644 --- a/logs/console_test.go +++ b/core/logs/console_test.go @@ -16,6 +16,9 @@ package logs import ( "testing" + "time" + + "github.com/stretchr/testify/assert" ) // Try each log level in decreasing order of priority. @@ -49,3 +52,33 @@ func TestConsoleNoColor(t *testing.T) { log.SetLogger("console", `{"color":false}`) testConsoleCalls(log) } + +// Test console async +func TestConsoleAsync(t *testing.T) { + log := NewLogger(100) + log.SetLogger("console") + log.Async() + // log.Close() + testConsoleCalls(log) + for len(log.msgChan) != 0 { + time.Sleep(1 * time.Millisecond) + } + log.Flush() + log.Close() +} + +func TestFormat(t *testing.T) { + log := newConsole() + lm := &LogMsg{ + Level: LevelDebug, + Msg: "Hello, world", + When: time.Date(2020, 9, 19, 20, 12, 37, 9, time.UTC), + FilePath: "/user/home/main.go", + LineNumber: 13, + Prefix: "Cus", + } + res := log.Format(lm) + assert.Equal(t, "2020/09/19 20:12:37.000 \x1b[1;44m[D]\x1b[0m Cus Hello, world", res) + err := log.WriteMsg(lm) + assert.Nil(t, err) +} diff --git a/core/logs/es/es.go b/core/logs/es/es.go new file mode 100644 index 0000000000..d1ba452d50 --- /dev/null +++ b/core/logs/es/es.go @@ -0,0 +1,123 @@ +package es + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "net/url" + "strings" + "time" + + "github.com/elastic/go-elasticsearch/v6" + "github.com/elastic/go-elasticsearch/v6/esapi" + + "github.com/beego/beego/v2/core/logs" +) + +// NewES returns a LoggerInterface +func NewES() logs.Logger { + cw := &esLogger{ + Level: logs.LevelDebug, + indexNaming: indexNaming, + } + return cw +} + +// esLogger will log msg into ES +// before you using this implementation, +// please import this package +// usually means that you can import this package in your main package +// for example, anonymous: +// import _ "github.com/beego/beego/v2/core/logs/es" +type esLogger struct { + *elasticsearch.Client + DSN string `json:"dsn"` + Level int `json:"level"` + formatter logs.LogFormatter + Formatter string `json:"formatter"` + + indexNaming IndexNaming +} + +func (el *esLogger) Format(lm *logs.LogMsg) string { + msg := lm.OldStyleFormat() + idx := LogDocument{ + Timestamp: lm.When.Format(time.RFC3339), + Msg: msg, + } + body, err := json.Marshal(idx) + if err != nil { + return msg + } + return string(body) +} + +func (el *esLogger) SetFormatter(f logs.LogFormatter) { + el.formatter = f +} + +// {"dsn":"http://localhost:9200/","level":1} +func (el *esLogger) Init(config string) error { + err := json.Unmarshal([]byte(config), el) + if err != nil { + return err + } + if el.DSN == "" { + return errors.New("empty dsn") + } else if u, err := url.Parse(el.DSN); err != nil { + return err + } else if u.Path == "" { + return errors.New("missing prefix") + } else { + conn, err := elasticsearch.NewClient(elasticsearch.Config{ + Addresses: []string{el.DSN}, + }) + if err != nil { + return err + } + el.Client = conn + } + if len(el.Formatter) > 0 { + fmtr, ok := logs.GetFormatter(el.Formatter) + if !ok { + return fmt.Errorf("the formatter with name: %s not found", el.Formatter) + } + el.formatter = fmtr + } + return nil +} + +// WriteMsg writes the msg and level into es +func (el *esLogger) WriteMsg(lm *logs.LogMsg) error { + if lm.Level > el.Level { + return nil + } + + msg := el.formatter.Format(lm) + + req := esapi.IndexRequest{ + Index: indexNaming.IndexName(lm), + DocumentType: "logs", + Body: strings.NewReader(msg), + } + _, err := req.Do(context.Background(), el.Client) + return err +} + +// Destroy is an empty method +func (el *esLogger) Destroy() { +} + +// Flush is an empty method +func (el *esLogger) Flush() { +} + +type LogDocument struct { + Timestamp string `json:"timestamp"` + Msg string `json:"msg"` +} + +func init() { + logs.Register(logs.AdapterEs, NewES) +} diff --git a/core/logs/es/index.go b/core/logs/es/index.go new file mode 100644 index 0000000000..7a31eff189 --- /dev/null +++ b/core/logs/es/index.go @@ -0,0 +1,39 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package es + +import ( + "fmt" + + "github.com/beego/beego/v2/core/logs" +) + +// IndexNaming generate the index name +type IndexNaming interface { + IndexName(lm *logs.LogMsg) string +} + +var indexNaming IndexNaming = &defaultIndexNaming{} + +// SetIndexNaming will register global IndexNaming +func SetIndexNaming(i IndexNaming) { + indexNaming = i +} + +type defaultIndexNaming struct{} + +func (d *defaultIndexNaming) IndexName(lm *logs.LogMsg) string { + return fmt.Sprintf("%04d.%02d.%02d", lm.When.Year(), lm.When.Month(), lm.When.Day()) +} diff --git a/core/logs/es/index_test.go b/core/logs/es/index_test.go new file mode 100644 index 0000000000..ce30763a3e --- /dev/null +++ b/core/logs/es/index_test.go @@ -0,0 +1,34 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package es + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "github.com/beego/beego/v2/core/logs" +) + +func TestDefaultIndexNaming_IndexName(t *testing.T) { + tm := time.Date(2020, 9, 12, 1, 34, 45, 234, time.UTC) + lm := &logs.LogMsg{ + When: tm, + } + + res := (&defaultIndexNaming{}).IndexName(lm) + assert.Equal(t, "2020.09.12", res) +} diff --git a/logs/file.go b/core/logs/file.go similarity index 74% rename from logs/file.go rename to core/logs/file.go index 588f786035..2a78cf059b 100644 --- a/logs/file.go +++ b/core/logs/file.go @@ -21,7 +21,6 @@ import ( "fmt" "io" "os" - "path" "path/filepath" "strconv" "strings" @@ -30,9 +29,14 @@ import ( ) // fileLogWriter implements LoggerInterface. -// It writes messages by lines limit, file size limit, or time frequency. +// Writes messages by lines limit, file size limit, or time frequency. type fileLogWriter struct { sync.RWMutex // write log order by order and atomic incr maxLinesCurLines and maxSizeCurSize + + Rotate bool `json:"rotate"` + Daily bool `json:"daily"` + Hourly bool `json:"hourly"` + // The opened file Filename string `json:"filename"` fileWriter *os.File @@ -49,29 +53,30 @@ type fileLogWriter struct { maxSizeCurSize int // Rotate daily - Daily bool `json:"daily"` MaxDays int64 `json:"maxdays"` dailyOpenDate int dailyOpenTime time.Time // Rotate hourly - Hourly bool `json:"hourly"` MaxHours int64 `json:"maxhours"` hourlyOpenDate int hourlyOpenTime time.Time - Rotate bool `json:"rotate"` - Level int `json:"level"` - + // Permissions for log file Perm string `json:"perm"` + // Permissions for directory if it is specified in FileName + DirPerm string `json:"dirperm"` RotatePerm string `json:"rotateperm"` fileNameOnly, suffix string // like "project.log", project is fileNameOnly and .log is suffix + + logFormatter LogFormatter + Formatter string `json:"formatter"` } -// newFileWriter create a FileLogWriter returning as LoggerInterface. +// newFileWriter creates a FileLogWriter returning as LoggerInterface. func newFileWriter() Logger { w := &fileLogWriter{ Daily: true, @@ -82,30 +87,44 @@ func newFileWriter() Logger { RotatePerm: "0440", Level: LevelTrace, Perm: "0660", + DirPerm: "0770", MaxLines: 10000000, MaxFiles: 999, MaxSize: 1 << 28, } + w.logFormatter = w return w } +func (*fileLogWriter) Format(lm *LogMsg) string { + msg := lm.OldStyleFormat() + hd, _, _ := formatTimeHeader(lm.When) + msg = fmt.Sprintf("%s %s\n", string(hd), msg) + return msg +} + +func (w *fileLogWriter) SetFormatter(f LogFormatter) { + w.logFormatter = f +} + // Init file logger with json config. // jsonConfig like: -// { -// "filename":"logs/beego.log", -// "maxLines":10000, -// "maxsize":1024, -// "daily":true, -// "maxDays":15, -// "rotate":true, -// "perm":"0600" -// } -func (w *fileLogWriter) Init(jsonConfig string) error { - err := json.Unmarshal([]byte(jsonConfig), w) +// +// { +// "filename":"logs/beego.log", +// "maxLines":10000, +// "maxsize":1024, +// "daily":true, +// "maxDays":15, +// "rotate":true, +// "perm":"0600" +// } +func (w *fileLogWriter) Init(config string) error { + err := json.Unmarshal([]byte(config), w) if err != nil { return err } - if len(w.Filename) == 0 { + if w.Filename == "" { return errors.New("jsonconfig must have filename") } w.suffix = filepath.Ext(w.Filename) @@ -113,6 +132,14 @@ func (w *fileLogWriter) Init(jsonConfig string) error { if w.suffix == "" { w.suffix = ".log" } + + if len(w.Formatter) > 0 { + fmtr, ok := GetFormatter(w.Formatter) + if !ok { + return fmt.Errorf("the formatter with name: %s not found", w.Formatter) + } + w.logFormatter = fmtr + } err = w.startLogger() return err } @@ -130,42 +157,43 @@ func (w *fileLogWriter) startLogger() error { return w.initFd() } -func (w *fileLogWriter) needRotateDaily(size int, day int) bool { +func (w *fileLogWriter) needRotateDaily(day int) bool { return (w.MaxLines > 0 && w.maxLinesCurLines >= w.MaxLines) || (w.MaxSize > 0 && w.maxSizeCurSize >= w.MaxSize) || (w.Daily && day != w.dailyOpenDate) } -func (w *fileLogWriter) needRotateHourly(size int, hour int) bool { +func (w *fileLogWriter) needRotateHourly(hour int) bool { return (w.MaxLines > 0 && w.maxLinesCurLines >= w.MaxLines) || (w.MaxSize > 0 && w.maxSizeCurSize >= w.MaxSize) || (w.Hourly && hour != w.hourlyOpenDate) - } -// WriteMsg write logger message into file. -func (w *fileLogWriter) WriteMsg(when time.Time, msg string, level int) error { - if level > w.Level { +// WriteMsg writes logger message into file. +func (w *fileLogWriter) WriteMsg(lm *LogMsg) error { + if lm.Level > w.Level { return nil } - hd, d, h := formatTimeHeader(when) - msg = string(hd) + msg + "\n" + + _, d, h := formatTimeHeader(lm.When) + + msg := w.logFormatter.Format(lm) if w.Rotate { w.RLock() - if w.needRotateHourly(len(msg), h) { + if w.needRotateHourly(h) { w.RUnlock() w.Lock() - if w.needRotateHourly(len(msg), h) { - if err := w.doRotate(when); err != nil { + if w.needRotateHourly(h) { + if err := w.doRotate(lm.When); err != nil { fmt.Fprintf(os.Stderr, "FileLogWriter(%q): %s\n", w.Filename, err) } } w.Unlock() - } else if w.needRotateDaily(len(msg), d) { + } else if w.needRotateDaily(d) { w.RUnlock() w.Lock() - if w.needRotateDaily(len(msg), d) { - if err := w.doRotate(when); err != nil { + if w.needRotateDaily(d) { + if err := w.doRotate(lm.When); err != nil { fmt.Fprintf(os.Stderr, "FileLogWriter(%q): %s\n", w.Filename, err) } } @@ -192,8 +220,13 @@ func (w *fileLogWriter) createLogFile() (*os.File, error) { return nil, err } - filepath := path.Dir(w.Filename) - os.MkdirAll(filepath, os.FileMode(perm)) + dirperm, err := strconv.ParseInt(w.DirPerm, 8, 64) + if err != nil { + return nil, err + } + + filepath := filepath.Dir(w.Filename) + os.MkdirAll(filepath, os.FileMode(dirperm)) fd, err := os.OpenFile(w.Filename, os.O_WRONLY|os.O_APPEND|os.O_CREATE, os.FileMode(perm)) if err == nil { @@ -236,7 +269,7 @@ func (w *fileLogWriter) dailyRotate(openTime time.Time) { tm := time.NewTimer(time.Duration(nextDay.UnixNano() - openTime.UnixNano() + 100)) <-tm.C w.Lock() - if w.needRotateDaily(0, time.Now().Day()) { + if w.needRotateDaily(time.Now().Day()) { if err := w.doRotate(time.Now()); err != nil { fmt.Fprintf(os.Stderr, "FileLogWriter(%q): %s\n", w.Filename, err) } @@ -251,7 +284,7 @@ func (w *fileLogWriter) hourlyRotate(openTime time.Time) { tm := time.NewTimer(time.Duration(nextHour.UnixNano() - openTime.UnixNano() + 100)) <-tm.C w.Lock() - if w.needRotateHourly(0, time.Now().Hour()) { + if w.needRotateHourly(time.Now().Hour()) { if err := w.doRotate(time.Now()); err != nil { fmt.Fprintf(os.Stderr, "FileLogWriter(%q): %s\n", w.Filename, err) } @@ -286,7 +319,7 @@ func (w *fileLogWriter) lines() (int, error) { return count, nil } -// DoRotate means it need to write file in new file. +// DoRotate means it needs to write logs into a new file. // new file name like xx.2013-01-01.log (daily) or xx.001.log (by line or size) func (w *fileLogWriter) doRotate(logTime time.Time) error { // file exists @@ -302,14 +335,14 @@ func (w *fileLogWriter) doRotate(logTime time.Time) error { _, err = os.Lstat(w.Filename) if err != nil { - //even if the file is not exist or other ,we should RESTART the logger + // even if the file is not exist or other ,we should RESTART the logger goto RESTART_LOGGER } if w.Hourly { format = "2006010215" openTime = w.hourlyOpenTime - } else if w.Daily { + } else { format = "2006-01-02" openTime = w.dailyOpenTime } @@ -359,6 +392,10 @@ RESTART_LOGGER: func (w *fileLogWriter) deleteOldLog() { dir := filepath.Dir(w.Filename) + absolutePath, err := filepath.EvalSymlinks(w.Filename) + if err == nil { + dir = filepath.Dir(absolutePath) + } filepath.Walk(dir, func(path string, info os.FileInfo, err error) (returnErr error) { defer func() { if r := recover(); r != nil { @@ -369,21 +406,21 @@ func (w *fileLogWriter) deleteOldLog() { if info == nil { return } - if w.Hourly { - if !info.IsDir() && info.ModTime().Add(1 * time.Hour * time.Duration(w.MaxHours)).Before(time.Now()) { - if strings.HasPrefix(filepath.Base(path), filepath.Base(w.fileNameOnly)) && - strings.HasSuffix(filepath.Base(path), w.suffix) { - os.Remove(path) - } - } - } else if w.Daily { - if !info.IsDir() && info.ModTime().Add(24 * time.Hour * time.Duration(w.MaxDays)).Before(time.Now()) { - if strings.HasPrefix(filepath.Base(path), filepath.Base(w.fileNameOnly)) && - strings.HasSuffix(filepath.Base(path), w.suffix) { - os.Remove(path) - } - } - } + if w.Hourly { + if !info.IsDir() && info.ModTime().Add(1*time.Hour*time.Duration(w.MaxHours)).Before(time.Now()) { + if strings.HasPrefix(filepath.Base(path), filepath.Base(w.fileNameOnly)) && + strings.HasSuffix(filepath.Base(path), w.suffix) { + os.Remove(path) + } + } + } else if w.Daily { + if !info.IsDir() && info.ModTime().Add(24*time.Hour*time.Duration(w.MaxDays)).Before(time.Now()) { + if strings.HasPrefix(filepath.Base(path), filepath.Base(w.fileNameOnly)) && + strings.HasSuffix(filepath.Base(path), w.suffix) { + os.Remove(path) + } + } + } return }) } @@ -393,7 +430,7 @@ func (w *fileLogWriter) Destroy() { w.fileWriter.Close() } -// Flush flush file logger. +// Flush flushes file logger. // there are no buffering messages in file logger in memory. // flush file means sync file from disk. func (w *fileLogWriter) Flush() { diff --git a/logs/file_test.go b/core/logs/file_test.go similarity index 78% rename from logs/file_test.go rename to core/logs/file_test.go index e7c2ca9aa5..d7ac2d2127 100644 --- a/logs/file_test.go +++ b/core/logs/file_test.go @@ -17,11 +17,12 @@ package logs import ( "bufio" "fmt" - "io/ioutil" "os" "strconv" "testing" "time" + + "github.com/stretchr/testify/assert" ) func TestFilePerm(t *testing.T) { @@ -40,12 +41,63 @@ func TestFilePerm(t *testing.T) { if err != nil { t.Fatal(err) } - if file.Mode() != 0666 { + if file.Mode() != 0o666 { t.Fatal("unexpected log file permission") } os.Remove("test.log") } +func TestFileWithPrefixPath(t *testing.T) { + log := NewLogger(10000) + log.SetLogger("file", `{"filename":"log/test.log"}`) + log.Debug("debug") + log.Informational("info") + log.Notice("notice") + log.Warning("warning") + log.Error("error") + log.Alert("alert") + log.Critical("critical") + log.Emergency("emergency") + _, err := os.Stat("log/test.log") + if err != nil { + t.Fatal(err) + } + os.Remove("log/test.log") + os.Remove("log") +} + +func TestFilePermWithPrefixPath(t *testing.T) { + log := NewLogger(10000) + log.SetLogger("file", `{"filename":"mylogpath/test.log", "perm": "0220", "dirperm": "0770"}`) + log.Debug("debug") + log.Informational("info") + log.Notice("notice") + log.Warning("warning") + log.Error("error") + log.Alert("alert") + log.Critical("critical") + log.Emergency("emergency") + + dir, err := os.Stat("mylogpath") + if err != nil { + t.Fatal(err) + } + if !dir.IsDir() { + t.Fatal("mylogpath expected to be a directory") + } + + file, err := os.Stat("mylogpath/test.log") + if err != nil { + t.Fatal(err) + } + if file.Mode() != 0o0220 { + t.Fatal("unexpected file permission") + } + + os.Remove("mylogpath/test.log") + os.Remove("mylogpath") +} + func TestFile1(t *testing.T) { log := NewLogger(10000) log.SetLogger("file", `{"filename":"test.log"}`) @@ -72,7 +124,7 @@ func TestFile1(t *testing.T) { lineNum++ } } - var expected = LevelDebug + 1 + expected := LevelDebug + 1 if lineNum != expected { t.Fatal(lineNum, "not "+strconv.Itoa(expected)+" lines") } @@ -105,7 +157,7 @@ func TestFile2(t *testing.T) { lineNum++ } } - var expected = LevelError + 1 + expected := LevelError + 1 if lineNum != expected { t.Fatal(lineNum, "not "+strconv.Itoa(expected)+" lines") } @@ -162,7 +214,8 @@ func TestFileDailyRotate_05(t *testing.T) { testFileDailyRotate(t, fn1, fn2) os.Remove(fn) } -func TestFileDailyRotate_06(t *testing.T) { //test file mode + +func TestFileDailyRotate_06(t *testing.T) { // test file mode log := NewLogger(10000) log.SetLogger("file", `{"filename":"test3.log","maxlines":4}`) log.Debug("debug") @@ -175,7 +228,7 @@ func TestFileDailyRotate_06(t *testing.T) { //test file mode log.Emergency("emergency") rotateName := "test3" + fmt.Sprintf(".%s.%03d", time.Now().Format("2006-01-02"), 1) + ".log" s, _ := os.Lstat(rotateName) - if s.Mode() != 0440 { + if s.Mode() != 0o440 { os.Remove(rotateName) os.Remove("test3.log") t.Fatal("rotate file mode error") @@ -186,7 +239,7 @@ func TestFileDailyRotate_06(t *testing.T) { //test file mode func TestFileHourlyRotate_01(t *testing.T) { log := NewLogger(10000) - log.SetLogger("file", `{"filename":"test3.log","hourly":true,"maxlines":4}`) + log.SetLogger("file", `{"filename":"test3.log","hourly":true,"maxlines":4}`) log.Debug("debug") log.Info("info") log.Notice("notice") @@ -235,9 +288,9 @@ func TestFileHourlyRotate_05(t *testing.T) { os.Remove(fn) } -func TestFileHourlyRotate_06(t *testing.T) { //test file mode +func TestFileHourlyRotate_06(t *testing.T) { // test file mode log := NewLogger(10000) - log.SetLogger("file", `{"filename":"test3.log", "hourly":true, "maxlines":4}`) + log.SetLogger("file", `{"filename":"test3.log", "hourly":true, "maxlines":4}`) log.Debug("debug") log.Info("info") log.Notice("notice") @@ -248,7 +301,7 @@ func TestFileHourlyRotate_06(t *testing.T) { //test file mode log.Emergency("emergency") rotateName := "test3" + fmt.Sprintf(".%s.%03d", time.Now().Format("2006010215"), 1) + ".log" s, _ := os.Lstat(rotateName) - if s.Mode() != 0440 { + if s.Mode() != 0o440 { os.Remove(rotateName) os.Remove("test3.log") t.Fatal("rotate file mode error") @@ -266,22 +319,29 @@ func testFileRotate(t *testing.T, fn1, fn2 string, daily, hourly bool) { Rotate: true, Level: LevelTrace, Perm: "0660", + DirPerm: "0770", RotatePerm: "0440", } + fw.logFormatter = fw - if daily { - fw.Init(fmt.Sprintf(`{"filename":"%v","maxdays":1}`, fn1)) - fw.dailyOpenTime = time.Now().Add(-24 * time.Hour) - fw.dailyOpenDate = fw.dailyOpenTime.Day() - } + if fw.Daily { + fw.Init(fmt.Sprintf(`{"filename":"%v","maxdays":1}`, fn1)) + fw.dailyOpenTime = time.Now().Add(-24 * time.Hour) + fw.dailyOpenDate = fw.dailyOpenTime.Day() + } - if hourly { - fw.Init(fmt.Sprintf(`{"filename":"%v","maxhours":1}`, fn1)) - fw.hourlyOpenTime = time.Now().Add(-1 * time.Hour) - fw.hourlyOpenDate = fw.hourlyOpenTime.Day() - } + if fw.Hourly { + fw.Init(fmt.Sprintf(`{"filename":"%v","maxhours":1}`, fn1)) + fw.hourlyOpenTime = time.Now().Add(-1 * time.Hour) + fw.hourlyOpenDate = fw.hourlyOpenTime.Day() + } + lm := &LogMsg{ + Msg: "Test message", + Level: LevelDebug, + When: time.Now(), + } - fw.WriteMsg(time.Now(), "this is a msg for test", LevelDebug) + fw.WriteMsg(lm) for _, file := range []string{fn1, fn2} { _, err := os.Stat(file) @@ -301,8 +361,11 @@ func testFileDailyRotate(t *testing.T, fn1, fn2 string) { Rotate: true, Level: LevelTrace, Perm: "0660", + DirPerm: "0770", RotatePerm: "0440", } + fw.logFormatter = fw + fw.Init(fmt.Sprintf(`{"filename":"%v","maxdays":1}`, fn1)) fw.dailyOpenTime = time.Now().Add(-24 * time.Hour) fw.dailyOpenDate = fw.dailyOpenTime.Day() @@ -314,7 +377,7 @@ func testFileDailyRotate(t *testing.T, fn1, fn2 string) { if err != nil { t.FailNow() } - content, err := ioutil.ReadFile(file) + content, err := os.ReadFile(file) if err != nil { t.FailNow() } @@ -328,13 +391,16 @@ func testFileDailyRotate(t *testing.T, fn1, fn2 string) { func testFileHourlyRotate(t *testing.T, fn1, fn2 string) { fw := &fileLogWriter{ - Hourly: true, - MaxHours: 168, + Hourly: true, + MaxHours: 168, Rotate: true, Level: LevelTrace, Perm: "0660", + DirPerm: "0770", RotatePerm: "0440", } + + fw.logFormatter = fw fw.Init(fmt.Sprintf(`{"filename":"%v","maxhours":1}`, fn1)) fw.hourlyOpenTime = time.Now().Add(-1 * time.Hour) fw.hourlyOpenDate = fw.hourlyOpenTime.Hour() @@ -346,7 +412,7 @@ func testFileHourlyRotate(t *testing.T, fn1, fn2 string) { if err != nil { t.FailNow() } - content, err := ioutil.ReadFile(file) + content, err := os.ReadFile(file) if err != nil { t.FailNow() } @@ -357,6 +423,7 @@ func testFileHourlyRotate(t *testing.T, fn1, fn2 string) { } fw.Destroy() } + func exists(path string) (bool, error) { _, err := os.Stat(path) if err == nil { @@ -418,3 +485,18 @@ func BenchmarkFileOnGoroutine(b *testing.B) { } os.Remove("test4.log") } + +func TestFileLogWriter_Format(t *testing.T) { + lg := &LogMsg{ + Level: LevelDebug, + Msg: "Hello, world", + When: time.Date(2020, 9, 19, 20, 12, 37, 9, time.UTC), + FilePath: "/user/home/main.go", + LineNumber: 13, + Prefix: "Cus", + } + + fw := newFileWriter().(*fileLogWriter) + res := fw.Format(lg) + assert.Equal(t, "2020/09/19 20:12:37.000 [D] Cus Hello, world\n", res) +} diff --git a/core/logs/formatter.go b/core/logs/formatter.go new file mode 100644 index 0000000000..6fa4f70c5d --- /dev/null +++ b/core/logs/formatter.go @@ -0,0 +1,91 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package logs + +import ( + "fmt" + "path" + "strconv" +) + +var formatterMap = make(map[string]LogFormatter, 4) + +type LogFormatter interface { + Format(lm *LogMsg) string +} + +// PatternLogFormatter provides a quick format method +// for example: +// tes := &PatternLogFormatter{Pattern: "%F:%n|%w %t>> %m", WhenFormat: "2006-01-02"} +// RegisterFormatter("tes", tes) +// SetGlobalFormatter("tes") +type PatternLogFormatter struct { + Pattern string + WhenFormat string +} + +func (p *PatternLogFormatter) getWhenFormatter() string { + s := p.WhenFormat + if s == "" { + s = "2006/01/02 15:04:05.123" // default style + } + return s +} + +func (p *PatternLogFormatter) Format(lm *LogMsg) string { + return p.ToString(lm) +} + +// RegisterFormatter register an formatter. Usually you should use this to extend your custom formatter +// for example: +// RegisterFormatter("my-fmt", &MyFormatter{}) +// logs.SetFormatter(Console, `{"formatter": "my-fmt"}`) +func RegisterFormatter(name string, fmtr LogFormatter) { + formatterMap[name] = fmtr +} + +func GetFormatter(name string) (LogFormatter, bool) { + res, ok := formatterMap[name] + return res, ok +} + +// ToString 'w' when, 'm' msg,'f' filename,'F' full path,'n' line number +// 'l' level number, 't' prefix of level type, 'T' full name of level type +func (p *PatternLogFormatter) ToString(lm *LogMsg) string { + s := []rune(p.Pattern) + msg := fmt.Sprintf(lm.Msg, lm.Args...) + m := map[rune]string{ + 'w': lm.When.Format(p.getWhenFormatter()), + 'm': msg, + 'n': strconv.Itoa(lm.LineNumber), + 'l': strconv.Itoa(lm.Level), + 't': levelPrefix[lm.Level], + 'T': levelNames[lm.Level], + 'F': lm.FilePath, + } + _, m['f'] = path.Split(lm.FilePath) + res := "" + for i := 0; i < len(s)-1; i++ { + if s[i] == '%' { + if k, ok := m[s[i+1]]; ok { + res += k + i++ + continue + } + } + res += string(s[i]) + } + return res +} diff --git a/core/logs/formatter_test.go b/core/logs/formatter_test.go new file mode 100644 index 0000000000..f78a2861e6 --- /dev/null +++ b/core/logs/formatter_test.go @@ -0,0 +1,116 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package logs + +import ( + "encoding/json" + "errors" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +type CustomFormatter struct{} + +func (*CustomFormatter) Format(lm *LogMsg) string { + return "hello, msg: " + lm.Msg +} + +type TestLogger struct { + Formatter string `json:"formatter"` + Expected string + formatter LogFormatter +} + +func (t *TestLogger) Init(config string) error { + er := json.Unmarshal([]byte(config), t) + t.formatter, _ = GetFormatter(t.Formatter) + return er +} + +func (t *TestLogger) WriteMsg(lm *LogMsg) error { + msg := t.formatter.Format(lm) + if msg != t.Expected { + return errors.New("not equal") + } + return nil +} + +func (*TestLogger) Destroy() { + panic("implement me") +} + +func (*TestLogger) Flush() { + panic("implement me") +} + +func (*TestLogger) SetFormatter(_ LogFormatter) { + panic("implement me") +} + +func TestCustomFormatter(t *testing.T) { + RegisterFormatter("custom", &CustomFormatter{}) + tl := &TestLogger{ + Expected: "hello, msg: world", + } + assert.Nil(t, tl.Init(`{"formatter": "custom"}`)) + assert.Nil(t, tl.WriteMsg(&LogMsg{ + Msg: "world", + })) +} + +func TestPatternLogFormatter(t *testing.T) { + tes := &PatternLogFormatter{ + Pattern: "%F:%n|%w%t>> %m", + WhenFormat: "2006-01-02", + } + when, _ := time.Parse(tes.WhenFormat, "2022-04-17") + testCases := []struct { + msg *LogMsg + want string + }{ + { + msg: &LogMsg{ + Msg: "hello %s", + FilePath: "/User/go/beego/main.go", + Level: LevelWarn, + LineNumber: 10, + When: when, + Args: []interface{}{"world"}, + }, + want: "/User/go/beego/main.go:10|2022-04-17[W]>> hello world", + }, + { + msg: &LogMsg{ + Msg: "hello", + FilePath: "/User/go/beego/main.go", + Level: LevelWarn, + LineNumber: 10, + When: when, + }, + want: "/User/go/beego/main.go:10|2022-04-17[W]>> hello", + }, + { + msg: &LogMsg{}, + want: ":0|0001-01-01[M]>> ", + }, + } + + for _, tc := range testCases { + got := tes.ToString(tc.msg) + assert.Equal(t, tc.want, got) + } +} diff --git a/logs/jianliao.go b/core/logs/jianliao.go similarity index 57% rename from logs/jianliao.go rename to core/logs/jianliao.go index 88ba0f9af4..5c43b40d8d 100644 --- a/logs/jianliao.go +++ b/core/logs/jianliao.go @@ -5,7 +5,6 @@ import ( "fmt" "net/http" "net/url" - "time" ) // JLWriter implements beego LoggerInterface and is used to send jiaoliao webhook @@ -16,26 +15,49 @@ type JLWriter struct { RedirectURL string `json:"redirecturl,omitempty"` ImageURL string `json:"imageurl,omitempty"` Level int `json:"level"` + + formatter LogFormatter + Formatter string `json:"formatter"` } -// newJLWriter create jiaoliao writer. +// newJLWriter creates jiaoliao writer. func newJLWriter() Logger { - return &JLWriter{Level: LevelTrace} + res := &JLWriter{Level: LevelTrace} + res.formatter = res + return res } // Init JLWriter with json config string -func (s *JLWriter) Init(jsonconfig string) error { - return json.Unmarshal([]byte(jsonconfig), s) +func (s *JLWriter) Init(config string) error { + res := json.Unmarshal([]byte(config), s) + if res == nil && len(s.Formatter) > 0 { + fmtr, ok := GetFormatter(s.Formatter) + if !ok { + return fmt.Errorf("the formatter with name: %s not found", s.Formatter) + } + s.formatter = fmtr + } + return res +} + +func (s *JLWriter) Format(lm *LogMsg) string { + msg := lm.OldStyleFormat() + msg = fmt.Sprintf("%s %s", lm.When.Format("2006-01-02 15:04:05"), msg) + return msg +} + +func (s *JLWriter) SetFormatter(f LogFormatter) { + s.formatter = f } -// WriteMsg write message in smtp writer. -// it will send an email with subject and only this message. -func (s *JLWriter) WriteMsg(when time.Time, msg string, level int) error { - if level > s.Level { +// WriteMsg writes message in smtp writer. +// Sends an email with subject and only this message. +func (s *JLWriter) WriteMsg(lm *LogMsg) error { + if lm.Level > s.Level { return nil } - text := fmt.Sprintf("%s %s", when.Format("2006-01-02 15:04:05"), msg) + text := s.formatter.Format(lm) form := url.Values{} form.Add("authorName", s.AuthorName) diff --git a/core/logs/jianliao_test.go b/core/logs/jianliao_test.go new file mode 100644 index 0000000000..a1b2d0762a --- /dev/null +++ b/core/logs/jianliao_test.go @@ -0,0 +1,36 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package logs + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestJLWriter_Format(t *testing.T) { + lg := &LogMsg{ + Level: LevelDebug, + Msg: "Hello, world", + When: time.Date(2020, 9, 19, 20, 12, 37, 9, time.UTC), + FilePath: "/user/home/main.go", + LineNumber: 13, + Prefix: "Cus", + } + jl := newJLWriter().(*JLWriter) + res := jl.Format(lg) + assert.Equal(t, "2020-09-19 20:12:37 [D] Cus Hello, world", res) +} diff --git a/logs/log.go b/core/logs/log.go similarity index 66% rename from logs/log.go rename to core/logs/log.go index 49f3794f34..aab5be295a 100644 --- a/logs/log.go +++ b/core/logs/log.go @@ -15,7 +15,7 @@ // Package logs provide a general log interface // Usage: // -// import "github.com/astaxie/beego/logs" +// import "github.com/beego/beego/v2/core/logs" // // log := NewLogger(10000) // log.SetLogger("console", "") @@ -29,17 +29,13 @@ // log.Warn("warning") // log.Debug("debug") // log.Critical("critical") -// -// more docs http://beego.me/docs/module/logs.md package logs import ( "fmt" "log" "os" - "path" "runtime" - "strconv" "strings" "sync" "time" @@ -86,13 +82,16 @@ type newLoggerFunc func() Logger // Logger defines the behavior of a log provider. type Logger interface { Init(config string) error - WriteMsg(when time.Time, msg string, level int) error + WriteMsg(lm *LogMsg) error Destroy() Flush() + SetFormatter(f LogFormatter) } -var adapters = make(map[string]newLoggerFunc) -var levelPrefix = [LevelDebug + 1]string{"[M]", "[A]", "[C]", "[E]", "[W]", "[N]", "[I]", "[D]"} +var ( + adapters = make(map[string]newLoggerFunc) + levelPrefix = [LevelDebug + 1]string{"[M]", "[A]", "[C]", "[E]", "[W]", "[N]", "[I]", "[D]"} +) // Register makes a log provide available by the provided name. // If Register is called twice with the same name or if driver is nil, @@ -108,20 +107,26 @@ func Register(name string, log newLoggerFunc) { } // BeeLogger is default logger in beego application. -// it can contain several providers and log message into all providers. +// Can contain several providers and log message into all providers. type BeeLogger struct { lock sync.Mutex - level int init bool enableFuncCallDepth bool - loggerFuncCallDepth int + enableFullFilePath bool asynchronous bool + // Whether to discard logs when buffer is full and asynchronous is true + // No discard by default + logWithNonBlocking bool + wg sync.WaitGroup + level int + loggerFuncCallDepth int prefix string msgChanLen int64 - msgChan chan *logMsg - signalChan chan string - wg sync.WaitGroup + msgChan chan *LogMsg + closeChan chan struct{} + flushChan chan struct{} outputs []*nameLogger + globalFormatter string } const defaultAsyncMsgLen = 1e3 @@ -131,31 +136,26 @@ type nameLogger struct { name string } -type logMsg struct { - level int - msg string - when time.Time -} - var logMsgPool *sync.Pool // NewLogger returns a new BeeLogger. -// channelLen means the number of messages in chan(used where asynchronous is true). +// channelLen: the number of messages in chan(used where asynchronous is true). // if the buffering chan is full, logger adapters write to file or other way. func NewLogger(channelLens ...int64) *BeeLogger { bl := new(BeeLogger) bl.level = LevelDebug - bl.loggerFuncCallDepth = 2 + bl.loggerFuncCallDepth = 3 bl.msgChanLen = append(channelLens, 0)[0] if bl.msgChanLen <= 0 { bl.msgChanLen = defaultAsyncMsgLen } - bl.signalChan = make(chan string, 1) + bl.flushChan = make(chan struct{}, 1) + bl.closeChan = make(chan struct{}, 1) bl.setLogger(AdapterConsole) return bl } -// Async set the log to asynchronous and start the goroutine +// Async sets the log to asynchronous and start the goroutine func (bl *BeeLogger) Async(msgLen ...int64) *BeeLogger { bl.lock.Lock() defer bl.lock.Unlock() @@ -166,10 +166,10 @@ func (bl *BeeLogger) Async(msgLen ...int64) *BeeLogger { if len(msgLen) > 0 && msgLen[0] > 0 { bl.msgChanLen = msgLen[0] } - bl.msgChan = make(chan *logMsg, bl.msgChanLen) + bl.msgChan = make(chan *LogMsg, bl.msgChanLen) logMsgPool = &sync.Pool{ New: func() interface{} { - return &logMsg{} + return &LogMsg{} }, } bl.wg.Add(1) @@ -177,8 +177,18 @@ func (bl *BeeLogger) Async(msgLen ...int64) *BeeLogger { return bl } +// AsyncNonBlockWrite Non-blocking write in asynchronous mode +// Only works if asynchronous write logging is set +func (bl *BeeLogger) AsyncNonBlockWrite() *BeeLogger { + if !bl.asynchronous { + return bl + } + bl.logWithNonBlocking = true + return bl +} + // SetLogger provides a given logger adapter into BeeLogger with config string. -// config need to be correct JSON as string: {"interval":360}. +// config must in JSON format like {"interval":360}} func (bl *BeeLogger) setLogger(adapterName string, configs ...string) error { config := append(configs, "{}")[0] for _, l := range bl.outputs { @@ -193,17 +203,27 @@ func (bl *BeeLogger) setLogger(adapterName string, configs ...string) error { } lg := logAdapter() + err := lg.Init(config) if err != nil { - fmt.Fprintln(os.Stderr, "logs.BeeLogger.SetLogger: "+err.Error()) return err } + + // Global formatter overrides the default set formatter + if len(bl.globalFormatter) > 0 { + fmtr, ok := GetFormatter(bl.globalFormatter) + if !ok { + return fmt.Errorf("the formatter with name: %s not found", bl.globalFormatter) + } + lg.SetFormatter(fmtr) + } + bl.outputs = append(bl.outputs, &nameLogger{name: adapterName, Logger: lg}) return nil } // SetLogger provides a given logger adapter into BeeLogger with config string. -// config need to be correct JSON as string: {"interval":360}. +// config must in JSON format like {"interval":360}} func (bl *BeeLogger) SetLogger(adapterName string, configs ...string) error { bl.lock.Lock() defer bl.lock.Unlock() @@ -214,11 +234,11 @@ func (bl *BeeLogger) SetLogger(adapterName string, configs ...string) error { return bl.setLogger(adapterName, configs...) } -// DelLogger remove a logger adapter in BeeLogger. +// DelLogger removes a logger adapter in BeeLogger. func (bl *BeeLogger) DelLogger(adapterName string) error { bl.lock.Lock() defer bl.lock.Unlock() - outputs := []*nameLogger{} + outputs := make([]*nameLogger, 0, len(bl.outputs)) for _, lg := range bl.outputs { if lg.name == adapterName { lg.Destroy() @@ -233,9 +253,9 @@ func (bl *BeeLogger) DelLogger(adapterName string) error { return nil } -func (bl *BeeLogger) writeToLoggers(when time.Time, msg string, level int) { +func (bl *BeeLogger) writeToLoggers(lm *LogMsg) { for _, l := range bl.outputs { - err := l.WriteMsg(when, msg, level) + err := l.WriteMsg(lm) if err != nil { fmt.Fprintf(os.Stderr, "unable to WriteMsg to adapter:%v,error:%v\n", l.name, err) } @@ -250,61 +270,83 @@ func (bl *BeeLogger) Write(p []byte) (n int, err error) { if p[len(p)-1] == '\n' { p = p[0 : len(p)-1] } + lm := &LogMsg{ + Msg: string(p), + Level: levelLoggerImpl, + When: time.Now(), + } + // set levelLoggerImpl to ensure all log message will be write out - err = bl.writeMsg(levelLoggerImpl, string(p)) + err = bl.writeMsg(lm) if err == nil { - return len(p), err + return len(p), nil } return 0, err } -func (bl *BeeLogger) writeMsg(logLevel int, msg string, v ...interface{}) error { +func (bl *BeeLogger) writeMsg(lm *LogMsg) error { if !bl.init { bl.lock.Lock() bl.setLogger(AdapterConsole) bl.lock.Unlock() } - if len(v) > 0 { - msg = fmt.Sprintf(msg, v...) - } - - msg = bl.prefix + " " + msg + var ( + file string + line int + ok bool + ) - when := time.Now() - if bl.enableFuncCallDepth { - _, file, line, ok := runtime.Caller(bl.loggerFuncCallDepth) - if !ok { - file = "???" - line = 0 - } - _, filename := path.Split(file) - msg = "[" + filename + ":" + strconv.Itoa(line) + "] " + msg + _, file, line, ok = runtime.Caller(bl.loggerFuncCallDepth) + if !ok { + file = "???" + line = 0 } + lm.FilePath = file + lm.LineNumber = line + lm.Prefix = bl.prefix + + lm.enableFullFilePath = bl.enableFullFilePath + lm.enableFuncCallDepth = bl.enableFuncCallDepth - //set level info in front of filename info - if logLevel == levelLoggerImpl { + // set level info in front of filename info + if lm.Level == levelLoggerImpl { // set to emergency to ensure all log will be print out correctly - logLevel = LevelEmergency - } else { - msg = levelPrefix[logLevel] + " " + msg + lm.Level = LevelEmergency } if bl.asynchronous { - lm := logMsgPool.Get().(*logMsg) - lm.level = logLevel - lm.msg = msg - lm.when = when - bl.msgChan <- lm + logM := logMsgPool.Get().(*LogMsg) + logM.Level = lm.Level + logM.Msg = lm.Msg + logM.When = lm.When + logM.Args = lm.Args + logM.FilePath = lm.FilePath + logM.LineNumber = lm.LineNumber + logM.Prefix = lm.Prefix + + if bl.outputs != nil { + if bl.logWithNonBlocking { + select { + case bl.msgChan <- lm: + // discard log when channel is full + default: + } + } else { + bl.msgChan <- lm + } + } else { + logMsgPool.Put(lm) + } } else { - bl.writeToLoggers(when, msg, logLevel) + bl.writeToLoggers(lm) } return nil } -// SetLevel Set log message level. +// SetLevel sets log message level. // If message level (such as LevelDebug) is higher than logger level (such as LevelWarning), -// log providers will not even be sent the message. +// log providers will not be sent the message. func (bl *BeeLogger) SetLevel(l int) { bl.level = l } @@ -340,19 +382,23 @@ func (bl *BeeLogger) startLogger() { gameOver := false for { select { - case bm := <-bl.msgChan: - bl.writeToLoggers(bm.when, bm.msg, bm.level) - logMsgPool.Put(bm) - case sg := <-bl.signalChan: - // Now should only send "flush" or "close" to bl.signalChan + case bm, ok := <-bl.msgChan: + // this is a terrible design to have a signal channel that accept two inputs + // so we only handle the msg if the channel is not closed + if ok { + bl.writeToLoggers(bm) + logMsgPool.Put(bm) + } + case <-bl.closeChan: bl.flush() - if sg == "close" { - for _, l := range bl.outputs { - l.Destroy() - } - bl.outputs = nil - gameOver = true + for _, l := range bl.outputs { + l.Destroy() } + bl.outputs = nil + gameOver = true + bl.wg.Done() + case <-bl.flushChan: + bl.flush() bl.wg.Done() } if gameOver { @@ -361,12 +407,33 @@ func (bl *BeeLogger) startLogger() { } } +func (bl *BeeLogger) setGlobalFormatter(fmtter string) error { + bl.globalFormatter = fmtter + return nil +} + +// SetGlobalFormatter sets the global formatter for all log adapters +// don't forget to register the formatter by invoking RegisterFormatter +func SetGlobalFormatter(fmtter string) error { + return beeLogger.setGlobalFormatter(fmtter) +} + // Emergency Log EMERGENCY level message. func (bl *BeeLogger) Emergency(format string, v ...interface{}) { if LevelEmergency > bl.level { return } - bl.writeMsg(LevelEmergency, format, v...) + + lm := &LogMsg{ + Level: LevelEmergency, + Msg: format, + When: time.Now(), + } + if len(v) > 0 { + lm.Msg = fmt.Sprintf(lm.Msg, v...) + } + + bl.writeMsg(lm) } // Alert Log ALERT level message. @@ -374,7 +441,14 @@ func (bl *BeeLogger) Alert(format string, v ...interface{}) { if LevelAlert > bl.level { return } - bl.writeMsg(LevelAlert, format, v...) + + lm := &LogMsg{ + Level: LevelAlert, + Msg: format, + When: time.Now(), + Args: v, + } + bl.writeMsg(lm) } // Critical Log CRITICAL level message. @@ -382,7 +456,14 @@ func (bl *BeeLogger) Critical(format string, v ...interface{}) { if LevelCritical > bl.level { return } - bl.writeMsg(LevelCritical, format, v...) + lm := &LogMsg{ + Level: LevelCritical, + Msg: format, + When: time.Now(), + Args: v, + } + + bl.writeMsg(lm) } // Error Log ERROR level message. @@ -390,7 +471,14 @@ func (bl *BeeLogger) Error(format string, v ...interface{}) { if LevelError > bl.level { return } - bl.writeMsg(LevelError, format, v...) + lm := &LogMsg{ + Level: LevelError, + Msg: format, + When: time.Now(), + Args: v, + } + + bl.writeMsg(lm) } // Warning Log WARNING level message. @@ -398,7 +486,14 @@ func (bl *BeeLogger) Warning(format string, v ...interface{}) { if LevelWarn > bl.level { return } - bl.writeMsg(LevelWarn, format, v...) + lm := &LogMsg{ + Level: LevelWarn, + Msg: format, + When: time.Now(), + Args: v, + } + + bl.writeMsg(lm) } // Notice Log NOTICE level message. @@ -406,7 +501,14 @@ func (bl *BeeLogger) Notice(format string, v ...interface{}) { if LevelNotice > bl.level { return } - bl.writeMsg(LevelNotice, format, v...) + lm := &LogMsg{ + Level: LevelNotice, + Msg: format, + When: time.Now(), + Args: v, + } + + bl.writeMsg(lm) } // Informational Log INFORMATIONAL level message. @@ -414,7 +516,14 @@ func (bl *BeeLogger) Informational(format string, v ...interface{}) { if LevelInfo > bl.level { return } - bl.writeMsg(LevelInfo, format, v...) + lm := &LogMsg{ + Level: LevelInfo, + Msg: format, + When: time.Now(), + Args: v, + } + + bl.writeMsg(lm) } // Debug Log DEBUG level message. @@ -422,7 +531,14 @@ func (bl *BeeLogger) Debug(format string, v ...interface{}) { if LevelDebug > bl.level { return } - bl.writeMsg(LevelDebug, format, v...) + lm := &LogMsg{ + Level: LevelDebug, + Msg: format, + When: time.Now(), + Args: v, + } + + bl.writeMsg(lm) } // Warn Log WARN level message. @@ -431,7 +547,14 @@ func (bl *BeeLogger) Warn(format string, v ...interface{}) { if LevelWarn > bl.level { return } - bl.writeMsg(LevelWarn, format, v...) + lm := &LogMsg{ + Level: LevelWarn, + Msg: format, + When: time.Now(), + Args: v, + } + + bl.writeMsg(lm) } // Info Log INFO level message. @@ -440,7 +563,14 @@ func (bl *BeeLogger) Info(format string, v ...interface{}) { if LevelInfo > bl.level { return } - bl.writeMsg(LevelInfo, format, v...) + lm := &LogMsg{ + Level: LevelInfo, + Msg: format, + When: time.Now(), + Args: v, + } + + bl.writeMsg(lm) } // Trace Log TRACE level message. @@ -449,13 +579,20 @@ func (bl *BeeLogger) Trace(format string, v ...interface{}) { if LevelDebug > bl.level { return } - bl.writeMsg(LevelDebug, format, v...) + lm := &LogMsg{ + Level: LevelDebug, + Msg: format, + When: time.Now(), + Args: v, + } + + bl.writeMsg(lm) } // Flush flush all chan data. func (bl *BeeLogger) Flush() { if bl.asynchronous { - bl.signalChan <- "flush" + bl.flushChan <- struct{}{} bl.wg.Wait() bl.wg.Add(1) return @@ -466,7 +603,7 @@ func (bl *BeeLogger) Flush() { // Close close logger, flush all chan data and destroy all adapters in BeeLogger. func (bl *BeeLogger) Close() { if bl.asynchronous { - bl.signalChan <- "close" + bl.closeChan <- struct{}{} bl.wg.Wait() close(bl.msgChan) } else { @@ -476,7 +613,8 @@ func (bl *BeeLogger) Close() { } bl.outputs = nil } - close(bl.signalChan) + close(bl.flushChan) + close(bl.closeChan) } // Reset close all outputs, and set bl.outputs to nil @@ -492,8 +630,11 @@ func (bl *BeeLogger) flush() { if bl.asynchronous { for { if len(bl.msgChan) > 0 { - bm := <-bl.msgChan - bl.writeToLoggers(bm.when, bm.msg, bm.level) + bm, ok := <-bl.msgChan + if !ok { + continue + } + bl.writeToLoggers(bm) logMsgPool.Put(bm) continue } @@ -543,6 +684,12 @@ func GetLogger(prefixes ...string) *log.Logger { return l } +// EnableFullFilePath enables full file path logging. Disabled by default +// e.g "/home/Documents/GitHub/beego/mainapp/" instead of "mainapp" +func EnableFullFilePath(b bool) { + beeLogger.enableFullFilePath = b +} + // Reset will remove all the adapter func Reset() { beeLogger.Reset() @@ -568,10 +715,10 @@ func EnableFuncCallDepth(b bool) { beeLogger.enableFuncCallDepth = b } -// SetLogFuncCall set the CallDepth, default is 4 +// SetLogFuncCall set the CallDepth, default is 3 func SetLogFuncCall(b bool) { beeLogger.EnableFuncCallDepth(b) - beeLogger.SetLogFuncCallDepth(4) + beeLogger.SetLogFuncCallDepth(3) } // SetLogFuncCallDepth set log funcCallDepth @@ -586,61 +733,61 @@ func SetLogger(adapter string, config ...string) error { // Emergency logs a message at emergency level. func Emergency(f interface{}, v ...interface{}) { - beeLogger.Emergency(formatLog(f, v...)) + beeLogger.Emergency(formatPattern(f, v...), v...) } // Alert logs a message at alert level. func Alert(f interface{}, v ...interface{}) { - beeLogger.Alert(formatLog(f, v...)) + beeLogger.Alert(formatPattern(f, v...), v...) } // Critical logs a message at critical level. func Critical(f interface{}, v ...interface{}) { - beeLogger.Critical(formatLog(f, v...)) + beeLogger.Critical(formatPattern(f, v...), v...) } // Error logs a message at error level. func Error(f interface{}, v ...interface{}) { - beeLogger.Error(formatLog(f, v...)) + beeLogger.Error(formatPattern(f, v...), v...) } // Warning logs a message at warning level. func Warning(f interface{}, v ...interface{}) { - beeLogger.Warn(formatLog(f, v...)) + beeLogger.Warn(formatPattern(f, v...), v...) } // Warn compatibility alias for Warning() func Warn(f interface{}, v ...interface{}) { - beeLogger.Warn(formatLog(f, v...)) + beeLogger.Warn(formatPattern(f, v...), v...) } // Notice logs a message at notice level. func Notice(f interface{}, v ...interface{}) { - beeLogger.Notice(formatLog(f, v...)) + beeLogger.Notice(formatPattern(f, v...), v...) } // Informational logs a message at info level. func Informational(f interface{}, v ...interface{}) { - beeLogger.Info(formatLog(f, v...)) + beeLogger.Info(formatPattern(f, v...), v...) } // Info compatibility alias for Warning() func Info(f interface{}, v ...interface{}) { - beeLogger.Info(formatLog(f, v...)) + beeLogger.Info(formatPattern(f, v...), v...) } // Debug logs a message at debug level. func Debug(f interface{}, v ...interface{}) { - beeLogger.Debug(formatLog(f, v...)) + beeLogger.Debug(formatPattern(f, v...), v...) } // Trace logs a message at trace level. // compatibility alias for Warning() func Trace(f interface{}, v ...interface{}) { - beeLogger.Trace(formatLog(f, v...)) + beeLogger.Trace(formatPattern(f, v...), v...) } -func formatLog(f interface{}, v ...interface{}) string { +func formatPattern(f interface{}, v ...interface{}) string { var msg string switch f.(type) { case string: @@ -648,10 +795,8 @@ func formatLog(f interface{}, v ...interface{}) string { if len(v) == 0 { return msg } - if strings.Contains(msg, "%") && !strings.Contains(msg, "%%") { - //format string - } else { - //do not contain format char + if !strings.Contains(msg, "%") { + // do not contain format char msg += strings.Repeat(" %v", len(v)) } default: @@ -661,5 +806,5 @@ func formatLog(f interface{}, v ...interface{}) string { } msg += strings.Repeat(" %v", len(v)) } - return fmt.Sprintf(msg, v...) + return msg } diff --git a/core/logs/log_msg.go b/core/logs/log_msg.go new file mode 100644 index 0000000000..f1a0b559c9 --- /dev/null +++ b/core/logs/log_msg.go @@ -0,0 +1,55 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package logs + +import ( + "fmt" + "path" + "time" +) + +type LogMsg struct { + Level int + Msg string + When time.Time + FilePath string + LineNumber int + Args []interface{} + Prefix string + enableFullFilePath bool + enableFuncCallDepth bool +} + +// OldStyleFormat you should never invoke this +func (lm *LogMsg) OldStyleFormat() string { + msg := lm.Msg + + if len(lm.Args) > 0 { + msg = fmt.Sprintf(lm.Msg, lm.Args...) + } + + msg = lm.Prefix + " " + msg + + if lm.enableFuncCallDepth { + filePath := lm.FilePath + if !lm.enableFullFilePath { + _, filePath = path.Split(filePath) + } + msg = fmt.Sprintf("[%s:%d] %s", filePath, lm.LineNumber, msg) + } + + msg = levelPrefix[lm.Level] + " " + msg + return msg +} diff --git a/core/logs/log_msg_test.go b/core/logs/log_msg_test.go new file mode 100644 index 0000000000..dcc7e7d4fc --- /dev/null +++ b/core/logs/log_msg_test.go @@ -0,0 +1,48 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package logs + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestLogMsg_OldStyleFormat(t *testing.T) { + lg := &LogMsg{ + Level: LevelDebug, + Msg: "Hello, world", + When: time.Date(2020, 9, 19, 20, 12, 37, 9, time.UTC), + FilePath: "/user/home/main.go", + LineNumber: 13, + Prefix: "Cus", + } + res := lg.OldStyleFormat() + assert.Equal(t, "[D] Cus Hello, world", res) + + lg.enableFuncCallDepth = true + res = lg.OldStyleFormat() + assert.Equal(t, "[D] [main.go:13] Cus Hello, world", res) + + lg.enableFullFilePath = true + + res = lg.OldStyleFormat() + assert.Equal(t, "[D] [/user/home/main.go:13] Cus Hello, world", res) + + lg.Msg = "hello, %s" + lg.Args = []interface{}{"world"} + assert.Equal(t, "[D] [/user/home/main.go:13] Cus hello, world", lg.OldStyleFormat()) +} diff --git a/core/logs/log_test.go b/core/logs/log_test.go new file mode 100644 index 0000000000..06f8b0306b --- /dev/null +++ b/core/logs/log_test.go @@ -0,0 +1,128 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package logs + +import ( + "encoding/json" + "fmt" + "io" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestBeeLoggerDelLogger(t *testing.T) { + prefix := "My-Cus" + l := GetLogger(prefix) + assert.NotNil(t, l) + l.Print("hello") + + GetLogger().Print("hello") + SetPrefix("aaa") + Info("hello") +} + +type mockLogger struct { + *logWriter + WriteCost time.Duration `json:"write_cost"` // Simulated log writing time consuming + writeCnt int // Count add 1 when writing log success, just for test result +} + +func NewMockLogger() Logger { + return &mockLogger{ + logWriter: &logWriter{writer: io.Discard}, + } +} + +func (m *mockLogger) Init(config string) error { + return json.Unmarshal([]byte(config), m) +} + +func (m *mockLogger) WriteMsg(lm *LogMsg) error { + m.Lock() + msg := lm.Msg + msg += "\n" + + time.Sleep(m.WriteCost) + if _, err := m.writer.Write([]byte(msg)); err != nil { + return err + } + + m.writeCnt++ + m.Unlock() + return nil +} + +func (m *mockLogger) GetCnt() int { + return m.writeCnt +} + +func (*mockLogger) Destroy() {} +func (*mockLogger) Flush() {} +func (*mockLogger) SetFormatter(_ LogFormatter) {} + +func TestBeeLogger_AsyncNonBlockWrite(t *testing.T) { + testCases := []struct { + name string + before func() + after func() + msgLen int64 + writeCost time.Duration + sendInterval time.Duration + writeCnt int + }{ + { + // Write log time is less than send log time, no blocking + name: "mock1", + after: func() { + _ = beeLogger.DelLogger("mock1") + }, + msgLen: 5, + writeCnt: 10, + writeCost: 200 * time.Millisecond, + sendInterval: 300 * time.Millisecond, + }, + { + // Write log time is less than send log time, discarded when blocking + name: "mock2", + after: func() { + _ = beeLogger.DelLogger("mock2") + }, + writeCnt: 5, + msgLen: 5, + writeCost: 200 * time.Millisecond, + sendInterval: 10 * time.Millisecond, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + Register(tc.name, NewMockLogger) + err := beeLogger.SetLogger(tc.name, fmt.Sprintf(`{"write_cost": %d}`, tc.writeCost)) + assert.Nil(t, err) + + l := beeLogger.Async(tc.msgLen) + l.AsyncNonBlockWrite() + + for i := 0; i < 10; i++ { + time.Sleep(tc.sendInterval) + l.Info(fmt.Sprintf("----%d----", i)) + } + time.Sleep(1 * time.Second) + assert.Equal(t, tc.writeCnt, l.outputs[0].Logger.(*mockLogger).writeCnt) + tc.after() + }) + } +} diff --git a/logs/logger.go b/core/logs/logger.go similarity index 94% rename from logs/logger.go rename to core/logs/logger.go index c7cf8a56ef..29e33a8266 100644 --- a/logs/logger.go +++ b/core/logs/logger.go @@ -30,11 +30,12 @@ func newLogWriter(wr io.Writer) *logWriter { return &logWriter{writer: wr} } -func (lg *logWriter) writeln(when time.Time, msg string) { +func (lg *logWriter) writeln(msg string) (int, error) { lg.Lock() - h, _, _ := formatTimeHeader(when) - lg.writer.Write(append(append(h, msg...), '\n')) + msg += "\n" + n, err := lg.writer.Write([]byte(msg)) lg.Unlock() + return n, err } const ( @@ -59,7 +60,7 @@ func formatTimeHeader(when time.Time) ([]byte, int, int) { y, mo, d := when.Date() h, mi, s := when.Clock() ns := when.Nanosecond() / 1000000 - //len("2006/01/02 15:04:05.123 ")==24 + // len("2006/01/02 15:04:05.123 ")==24 var buf [24]byte buf[0] = y1[y/1000%10] @@ -111,8 +112,10 @@ var ( reset = string([]byte{27, 91, 48, 109}) ) -var once sync.Once -var colorMap map[string]string +var ( + once sync.Once + colorMap map[string]string +) func initColor() { if runtime.GOOS == "windows" { @@ -125,12 +128,12 @@ func initColor() { cyan = w32Cyan } colorMap = map[string]string{ - //by color + // by color "green": green, "white": white, "yellow": yellow, "red": red, - //by method + // by method "GET": blue, "POST": cyan, "PUT": yellow, diff --git a/logs/logger_test.go b/core/logs/logger_test.go similarity index 100% rename from logs/logger_test.go rename to core/logs/logger_test.go diff --git a/logs/multifile.go b/core/logs/multifile.go similarity index 83% rename from logs/multifile.go rename to core/logs/multifile.go index 901682743f..009bd32f70 100644 --- a/logs/multifile.go +++ b/core/logs/multifile.go @@ -16,7 +16,6 @@ package logs import ( "encoding/json" - "time" ) // A filesLogWriter manages several fileLogWriter @@ -54,11 +53,17 @@ func (f *multiFileLogWriter) Init(config string) error { f.fullLogWriter = writer f.writers[LevelDebug+1] = writer - //unmarshal "separate" field to f.Separate - json.Unmarshal([]byte(config), f) + // unmarshal "separate" field to f.Separate + err = json.Unmarshal([]byte(config), f) + if err != nil { + return err + } jsonMap := map[string]interface{}{} - json.Unmarshal([]byte(config), &jsonMap) + err = json.Unmarshal([]byte(config), &jsonMap) + if err != nil { + return err + } for i := LevelEmergency; i < LevelDebug+1; i++ { for _, v := range f.Separate { @@ -75,10 +80,17 @@ func (f *multiFileLogWriter) Init(config string) error { } } } - return nil } +func (*multiFileLogWriter) Format(lm *LogMsg) string { + return lm.OldStyleFormat() +} + +func (f *multiFileLogWriter) SetFormatter(fmt LogFormatter) { + f.fullLogWriter.SetFormatter(fmt) +} + func (f *multiFileLogWriter) Destroy() { for i := 0; i < len(f.writers); i++ { if f.writers[i] != nil { @@ -87,14 +99,14 @@ func (f *multiFileLogWriter) Destroy() { } } -func (f *multiFileLogWriter) WriteMsg(when time.Time, msg string, level int) error { +func (f *multiFileLogWriter) WriteMsg(lm *LogMsg) error { if f.fullLogWriter != nil { - f.fullLogWriter.WriteMsg(when, msg, level) + f.fullLogWriter.WriteMsg(lm) } for i := 0; i < len(f.writers)-1; i++ { if f.writers[i] != nil { - if level == f.writers[i].Level { - f.writers[i].WriteMsg(when, msg, level) + if lm.Level == f.writers[i].Level { + f.writers[i].WriteMsg(lm) } } } @@ -111,7 +123,8 @@ func (f *multiFileLogWriter) Flush() { // newFilesWriter create a FileLogWriter returning as LoggerInterface. func newFilesWriter() Logger { - return &multiFileLogWriter{} + res := &multiFileLogWriter{} + return res } func init() { diff --git a/logs/multifile_test.go b/core/logs/multifile_test.go similarity index 98% rename from logs/multifile_test.go rename to core/logs/multifile_test.go index 57b960945e..8050bc6bd7 100644 --- a/logs/multifile_test.go +++ b/core/logs/multifile_test.go @@ -60,7 +60,7 @@ func TestFiles_1(t *testing.T) { lineNum++ } } - var expected = 1 + expected := 1 if fn == "" { expected = LevelDebug + 1 } @@ -74,5 +74,4 @@ func TestFiles_1(t *testing.T) { } os.Remove(file) } - } diff --git a/core/logs/slack.go b/core/logs/slack.go new file mode 100644 index 0000000000..b631ba55eb --- /dev/null +++ b/core/logs/slack.go @@ -0,0 +1,82 @@ +package logs + +import ( + "bytes" + "encoding/json" + "fmt" + "net/http" +) + +// SLACKWriter implements beego LoggerInterface and is used to send jiaoliao webhook +type SLACKWriter struct { + WebhookURL string `json:"webhookurl"` + Level int `json:"level"` + formatter LogFormatter + Formatter string `json:"formatter"` +} + +// newSLACKWriter creates jiaoliao writer. +func newSLACKWriter() Logger { + res := &SLACKWriter{Level: LevelTrace} + res.formatter = res + return res +} + +func (s *SLACKWriter) Format(lm *LogMsg) string { + // text := fmt.Sprintf("{\"text\": \"%s\"}", msg) + return lm.When.Format("2006-01-02 15:04:05") + " " + lm.OldStyleFormat() +} + +func (s *SLACKWriter) SetFormatter(f LogFormatter) { + s.formatter = f +} + +// Init SLACKWriter with json config string +func (s *SLACKWriter) Init(config string) error { + res := json.Unmarshal([]byte(config), s) + + if res == nil && len(s.Formatter) > 0 { + fmtr, ok := GetFormatter(s.Formatter) + if !ok { + return fmt.Errorf("the formatter with name: %s not found", s.Formatter) + } + s.formatter = fmtr + } + + return res +} + +// WriteMsg write message in smtp writer. +// Sends an email with subject and only this message. +func (s *SLACKWriter) WriteMsg(lm *LogMsg) error { + if lm.Level > s.Level { + return nil + } + msg := s.Format(lm) + m := make(map[string]string, 1) + m["text"] = msg + + body, _ := json.Marshal(m) + // resp, err := http.PostForm(s.WebhookURL, form) + resp, err := http.Post(s.WebhookURL, "application/json", bytes.NewReader(body)) + if err != nil { + return err + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("Post webhook failed %s %d", resp.Status, resp.StatusCode) + } + return nil +} + +// Flush implementing method. empty. +func (s *SLACKWriter) Flush() { +} + +// Destroy implementing method. empty. +func (s *SLACKWriter) Destroy() { +} + +func init() { + Register(AdapterSlack, newSLACKWriter) +} diff --git a/core/logs/slack_test.go b/core/logs/slack_test.go new file mode 100644 index 0000000000..31f5cf97f5 --- /dev/null +++ b/core/logs/slack_test.go @@ -0,0 +1,44 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package logs + +// func TestSLACKWriter_WriteMsg(t *testing.T) { +// sc := ` +// { +// "webhookurl":"", +// "level":7 +// } +// ` +// l := newSLACKWriter() +// err := l.Init(sc) +// if err != nil { +// Debug(err) +// } +// +// err = l.WriteMsg(&LogMsg{ +// Level: 7, +// Msg: `{ "abs"`, +// When: time.Now(), +// FilePath: "main.go", +// LineNumber: 100, +// enableFullFilePath: true, +// enableFuncCallDepth: true, +// }) +// +// if err != nil { +// Debug(err) +// } +// +// } diff --git a/logs/smtp.go b/core/logs/smtp.go similarity index 69% rename from logs/smtp.go rename to core/logs/smtp.go index 6208d7b859..03ef422064 100644 --- a/logs/smtp.go +++ b/core/logs/smtp.go @@ -21,7 +21,6 @@ import ( "net" "net/smtp" "strings" - "time" ) // SMTPWriter implements LoggerInterface and is used to send emails via given SMTP-server. @@ -33,26 +32,43 @@ type SMTPWriter struct { FromAddress string `json:"fromAddress"` RecipientAddresses []string `json:"sendTos"` Level int `json:"level"` + // InsecureSkipVerify default value: true + InsecureSkipVerify bool `json:"insecureSkipVerify"` + + formatter LogFormatter + Formatter string `json:"formatter"` } -// NewSMTPWriter create smtp writer. +// NewSMTPWriter creates the smtp writer. func newSMTPWriter() Logger { - return &SMTPWriter{Level: LevelTrace} + res := &SMTPWriter{Level: LevelTrace, InsecureSkipVerify: true} + res.formatter = res + return res } // Init smtp writer with json config. // config like: -// { -// "username":"example@gmail.com", -// "password:"password", -// "host":"smtp.gmail.com:465", -// "subject":"email title", -// "fromAddress":"from@example.com", -// "sendTos":["email1","email2"], -// "level":LevelError -// } -func (s *SMTPWriter) Init(jsonconfig string) error { - return json.Unmarshal([]byte(jsonconfig), s) +// +// { +// "username":"example@gmail.com", +// "password:"password", +// "host":"smtp.gmail.com:465", +// "subject":"email title", +// "fromAddress":"from@example.com", +// "sendTos":["email1","email2"], +// "level":LevelError, +// "insecureSkipVerify": false +// } +func (s *SMTPWriter) Init(config string) error { + res := json.Unmarshal([]byte(config), s) + if res == nil && len(s.Formatter) > 0 { + fmtr, ok := GetFormatter(s.Formatter) + if !ok { + return fmt.Errorf("the formatter with name: %s not found", s.Formatter) + } + s.formatter = fmtr + } + return res } func (s *SMTPWriter) getSMTPAuth(host string) smtp.Auth { @@ -67,6 +83,10 @@ func (s *SMTPWriter) getSMTPAuth(host string) smtp.Auth { ) } +func (s *SMTPWriter) SetFormatter(f LogFormatter) { + s.formatter = f +} + func (s *SMTPWriter) sendMail(hostAddressWithPort string, auth smtp.Auth, fromAddress string, recipients []string, msgContent []byte) error { client, err := smtp.Dial(hostAddressWithPort) if err != nil { @@ -75,7 +95,7 @@ func (s *SMTPWriter) sendMail(hostAddressWithPort string, auth smtp.Auth, fromAd host, _, _ := net.SplitHostPort(hostAddressWithPort) tlsConn := &tls.Config{ - InsecureSkipVerify: true, + InsecureSkipVerify: s.InsecureSkipVerify, ServerName: host, } if err = client.StartTLS(tlsConn); err != nil { @@ -115,10 +135,14 @@ func (s *SMTPWriter) sendMail(hostAddressWithPort string, auth smtp.Auth, fromAd return client.Quit() } -// WriteMsg write message in smtp writer. -// it will send an email with subject and only this message. -func (s *SMTPWriter) WriteMsg(when time.Time, msg string, level int) error { - if level > s.Level { +func (s *SMTPWriter) Format(lm *LogMsg) string { + return lm.OldStyleFormat() +} + +// WriteMsg writes message in smtp writer. +// Sends an email with subject and only this message. +func (s *SMTPWriter) WriteMsg(lm *LogMsg) error { + if lm.Level > s.Level { return nil } @@ -127,11 +151,13 @@ func (s *SMTPWriter) WriteMsg(when time.Time, msg string, level int) error { // Set up authentication information. auth := s.getSMTPAuth(hp[0]) + msg := s.Format(lm) + // Connect to the server, authenticate, set the sender and recipient, // and send the email all in one step. contentType := "Content-Type: text/plain" + "; charset=UTF-8" mailmsg := []byte("To: " + strings.Join(s.RecipientAddresses, ";") + "\r\nFrom: " + s.FromAddress + "<" + s.FromAddress + - ">\r\nSubject: " + s.Subject + "\r\n" + contentType + "\r\n\r\n" + fmt.Sprintf(".%s", when.Format("2006-01-02 15:04:05")) + msg) + ">\r\nSubject: " + s.Subject + "\r\n" + contentType + "\r\n\r\n" + fmt.Sprintf(".%s", lm.When.Format("2006-01-02 15:04:05")) + msg) return s.sendMail(s.Host, auth, s.FromAddress, s.RecipientAddresses, mailmsg) } diff --git a/logs/smtp_test.go b/core/logs/smtp_test.go similarity index 100% rename from logs/smtp_test.go rename to core/logs/smtp_test.go diff --git a/utils/caller.go b/core/utils/caller.go similarity index 100% rename from utils/caller.go rename to core/utils/caller.go diff --git a/utils/caller_test.go b/core/utils/caller_test.go similarity index 100% rename from utils/caller_test.go rename to core/utils/caller_test.go diff --git a/utils/debug.go b/core/utils/debug.go similarity index 90% rename from utils/debug.go rename to core/utils/debug.go index 93c27b70d4..93279181d1 100644 --- a/utils/debug.go +++ b/core/utils/debug.go @@ -47,20 +47,20 @@ func GetDisplayString(data ...interface{}) string { } func display(displayed bool, data ...interface{}) string { - var pc, file, line, ok = runtime.Caller(2) + pc, file, line, ok := runtime.Caller(2) if !ok { return "" } - var buf = new(bytes.Buffer) + buf := new(bytes.Buffer) fmt.Fprintf(buf, "[Debug] at %s() [%s:%d]\n", function(pc), file, line) fmt.Fprintf(buf, "\n[Variables]\n") for i := 0; i < len(data); i += 2 { - var output = fomateinfo(len(data[i].(string))+3, data[i+1]) + output := fomateinfo(len(data[i].(string))+3, data[i+1]) fmt.Fprintf(buf, "%s = %s", data[i], output) } @@ -72,7 +72,7 @@ func display(displayed bool, data ...interface{}) string { // return data dump and format bytes func fomateinfo(headlen int, data ...interface{}) []byte { - var buf = new(bytes.Buffer) + buf := new(bytes.Buffer) if len(data) > 1 { fmt.Fprint(buf, " ") @@ -83,9 +83,9 @@ func fomateinfo(headlen int, data ...interface{}) []byte { } for k, v := range data { - var buf2 = new(bytes.Buffer) + buf2 := new(bytes.Buffer) var pointers *pointerInfo - var interfaces = make([]reflect.Value, 0, 10) + interfaces := make([]reflect.Value, 0, 10) printKeyValue(buf2, reflect.ValueOf(v), &pointers, &interfaces, nil, true, " ", 1) @@ -140,13 +140,13 @@ func isSimpleType(val reflect.Value, kind reflect.Kind, pointers **pointerInfo, return true } - var elem = val.Elem() + elem := val.Elem() if isSimpleType(elem, elem.Kind(), pointers, interfaces) { return true } - var addr = val.Elem().UnsafeAddr() + addr := val.Elem().UnsafeAddr() for p := *pointers; p != nil; p = p.prev { if addr == p.addr { @@ -162,7 +162,7 @@ func isSimpleType(val reflect.Value, kind reflect.Kind, pointers **pointerInfo, // dump value func printKeyValue(buf *bytes.Buffer, val reflect.Value, pointers **pointerInfo, interfaces *[]reflect.Value, structFilter func(string, string) bool, formatOutput bool, indent string, level int) { - var t = val.Kind() + t := val.Kind() switch t { case reflect.Bool: @@ -183,7 +183,7 @@ func printKeyValue(buf *bytes.Buffer, val reflect.Value, pointers **pointerInfo, return } - var addr = val.Elem().UnsafeAddr() + addr := val.Elem().UnsafeAddr() for p := *pointers; p != nil; p = p.prev { if addr == p.addr { @@ -206,7 +206,7 @@ func printKeyValue(buf *bytes.Buffer, val reflect.Value, pointers **pointerInfo, case reflect.String: fmt.Fprint(buf, "\"", val.String(), "\"") case reflect.Interface: - var value = val.Elem() + value := val.Elem() if !value.IsValid() { fmt.Fprint(buf, "nil") @@ -223,7 +223,7 @@ func printKeyValue(buf *bytes.Buffer, val reflect.Value, pointers **pointerInfo, printKeyValue(buf, value, pointers, interfaces, structFilter, formatOutput, indent, level+1) } case reflect.Struct: - var t = val.Type() + t := val.Type() fmt.Fprint(buf, t) fmt.Fprint(buf, "{") @@ -235,7 +235,7 @@ func printKeyValue(buf *bytes.Buffer, val reflect.Value, pointers **pointerInfo, fmt.Fprint(buf, " ") } - var name = t.Field(i).Name + name := t.Field(i).Name if formatOutput { for ind := 0; ind < level; ind++ { @@ -270,12 +270,12 @@ func printKeyValue(buf *bytes.Buffer, val reflect.Value, pointers **pointerInfo, fmt.Fprint(buf, val.Type()) fmt.Fprint(buf, "{") - var allSimple = true + allSimple := true for i := 0; i < val.Len(); i++ { - var elem = val.Index(i) + elem := val.Index(i) - var isSimple = isSimpleType(elem, elem.Kind(), pointers, interfaces) + isSimple := isSimpleType(elem, elem.Kind(), pointers, interfaces) if !isSimple { allSimple = false @@ -312,18 +312,18 @@ func printKeyValue(buf *bytes.Buffer, val reflect.Value, pointers **pointerInfo, fmt.Fprint(buf, "}") case reflect.Map: - var t = val.Type() - var keys = val.MapKeys() + t := val.Type() + keys := val.MapKeys() fmt.Fprint(buf, t) fmt.Fprint(buf, "{") - var allSimple = true + allSimple := true for i := 0; i < len(keys); i++ { - var elem = val.MapIndex(keys[i]) + elem := val.MapIndex(keys[i]) - var isSimple = isSimpleType(elem, elem.Kind(), pointers, interfaces) + isSimple := isSimpleType(elem, elem.Kind(), pointers, interfaces) if !isSimple { allSimple = false @@ -372,8 +372,8 @@ func printKeyValue(buf *bytes.Buffer, val reflect.Value, pointers **pointerInfo, // PrintPointerInfo dump pointer value func PrintPointerInfo(buf *bytes.Buffer, headlen int, pointers *pointerInfo) { - var anyused = false - var pointerNum = 0 + anyused := false + pointerNum := 0 for p := pointers; p != nil; p = p.prev { if len(p.used) > 0 { @@ -384,10 +384,10 @@ func PrintPointerInfo(buf *bytes.Buffer, headlen int, pointers *pointerInfo) { } if anyused { - var pointerBufs = make([][]rune, pointerNum+1) + pointerBufs := make([][]rune, pointerNum+1) for i := 0; i < len(pointerBufs); i++ { - var pointerBuf = make([]rune, buf.Len()+headlen) + pointerBuf := make([]rune, buf.Len()+headlen) for j := 0; j < len(pointerBuf); j++ { pointerBuf[j] = ' ' @@ -402,7 +402,7 @@ func PrintPointerInfo(buf *bytes.Buffer, headlen int, pointers *pointerInfo) { if pn == p.n { pointerBufs[pn][p.pos+headlen] = '└' - var maxpos = 0 + maxpos := 0 for i, pos := range p.used { if i < len(p.used)-1 { @@ -440,10 +440,10 @@ func PrintPointerInfo(buf *bytes.Buffer, headlen int, pointers *pointerInfo) { // Stack get stack bytes func Stack(skip int, indent string) []byte { - var buf = new(bytes.Buffer) + buf := new(bytes.Buffer) for i := skip; ; i++ { - var pc, file, line, ok = runtime.Caller(i) + pc, file, line, ok := runtime.Caller(i) if !ok { break diff --git a/utils/debug_test.go b/core/utils/debug_test.go similarity index 95% rename from utils/debug_test.go rename to core/utils/debug_test.go index efb8924ec9..a748d20aa3 100644 --- a/utils/debug_test.go +++ b/core/utils/debug_test.go @@ -28,8 +28,8 @@ func TestPrint(t *testing.T) { } func TestPrintPoint(t *testing.T) { - var v1 = new(mytype) - var v2 = new(mytype) + v1 := new(mytype) + v2 := new(mytype) v1.prev = nil v1.next = v2 diff --git a/utils/file.go b/core/utils/file.go similarity index 99% rename from utils/file.go rename to core/utils/file.go index 6090eb1710..2310a6b326 100644 --- a/utils/file.go +++ b/core/utils/file.go @@ -69,6 +69,7 @@ func GrepFile(patten string, filename string) (lines []string, err error) { if err != nil { return } + defer fd.Close() lines = make([]string, 0) reader := bufio.NewReader(fd) prefix := "" diff --git a/utils/file_test.go b/core/utils/file_test.go similarity index 100% rename from utils/file_test.go rename to core/utils/file_test.go diff --git a/core/utils/kv.go b/core/utils/kv.go new file mode 100644 index 0000000000..7fc5cee93c --- /dev/null +++ b/core/utils/kv.go @@ -0,0 +1,87 @@ +// Copyright 2020 beego-dev +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +type KV interface { + GetKey() interface{} + GetValue() interface{} +} + +// SimpleKV is common structure to store key-value pairs. +// When you need something like Pair, you can use this +type SimpleKV struct { + Key interface{} + Value interface{} +} + +var _ KV = new(SimpleKV) + +func (s *SimpleKV) GetKey() interface{} { + return s.Key +} + +func (s *SimpleKV) GetValue() interface{} { + return s.Value +} + +// KVs interface +type KVs interface { + GetValueOr(key interface{}, defValue interface{}) interface{} + Contains(key interface{}) bool + IfContains(key interface{}, action func(value interface{})) KVs +} + +// SimpleKVs will store SimpleKV collection as map +type SimpleKVs struct { + kvs map[interface{}]interface{} +} + +var _ KVs = new(SimpleKVs) + +// GetValueOr returns the value for a given key, if non-existent +// it returns defValue +func (kvs *SimpleKVs) GetValueOr(key interface{}, defValue interface{}) interface{} { + v, ok := kvs.kvs[key] + if ok { + return v + } + return defValue +} + +// Contains checks if a key exists +func (kvs *SimpleKVs) Contains(key interface{}) bool { + _, ok := kvs.kvs[key] + return ok +} + +// IfContains invokes the action on a key if it exists +func (kvs *SimpleKVs) IfContains(key interface{}, action func(value interface{})) KVs { + v, ok := kvs.kvs[key] + if ok { + action(v) + } + return kvs +} + +// NewKVs creates the *KVs instance +func NewKVs(kvs ...KV) KVs { + res := &SimpleKVs{ + kvs: make(map[interface{}]interface{}, len(kvs)), + } + for _, kv := range kvs { + res.kvs[kv.GetKey()] = kv.GetValue() + } + return res +} diff --git a/core/utils/kv_test.go b/core/utils/kv_test.go new file mode 100644 index 0000000000..cd0dc5c398 --- /dev/null +++ b/core/utils/kv_test.go @@ -0,0 +1,37 @@ +// Copyright 2020 beego-dev +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestKVs(t *testing.T) { + key := "my-key" + kvs := NewKVs(&SimpleKV{ + Key: key, + Value: 12, + }) + + assert.True(t, kvs.Contains(key)) + + v := kvs.GetValueOr(key, 13) + assert.Equal(t, 12, v) + + v = kvs.GetValueOr(`key-not-exists`, 8546) + assert.Equal(t, 8546, v) +} diff --git a/utils/mail.go b/core/utils/mail.go similarity index 97% rename from utils/mail.go rename to core/utils/mail.go index e3fa1c9099..0f1f95ccf1 100644 --- a/utils/mail.go +++ b/core/utils/mail.go @@ -120,7 +120,7 @@ func (e *Email) Bytes() ([]byte, error) { } // Create the body sections if e.Text != "" { - header.Set("Content-Type", fmt.Sprintf("text/plain; charset=UTF-8")) + header.Set("Content-Type", "text/plain; charset=UTF-8") header.Set("Content-Transfer-Encoding", "quoted-printable") if _, err := subWriter.CreatePart(header); err != nil { return nil, err @@ -131,7 +131,7 @@ func (e *Email) Bytes() ([]byte, error) { } } if e.HTML != "" { - header.Set("Content-Type", fmt.Sprintf("text/html; charset=UTF-8")) + header.Set("Content-Type", "text/html; charset=UTF-8") header.Set("Content-Transfer-Encoding", "quoted-printable") if _, err := subWriter.CreatePart(header); err != nil { return nil, err @@ -162,7 +162,7 @@ func (e *Email) Bytes() ([]byte, error) { // AttachFile Add attach file to the send mail func (e *Email) AttachFile(args ...string) (a *Attachment, err error) { - if len(args) < 1 && len(args) > 2 { + if len(args) < 1 || len(args) > 2 { // change && to || err = errors.New("Must specify a file name and number of parameters can not exceed at least two") return } @@ -175,6 +175,7 @@ func (e *Email) AttachFile(args ...string) (a *Attachment, err error) { if err != nil { return } + defer f.Close() ct := mime.TypeByExtension(filepath.Ext(filename)) basename := path.Base(filename) return e.Attach(f, basename, ct, id) @@ -183,14 +184,14 @@ func (e *Email) AttachFile(args ...string) (a *Attachment, err error) { // Attach is used to attach content from an io.Reader to the email. // Parameters include an io.Reader, the desired filename for the attachment, and the Content-Type. func (e *Email) Attach(r io.Reader, filename string, args ...string) (a *Attachment, err error) { - if len(args) < 1 && len(args) > 2 { + if len(args) < 1 || len(args) > 2 { // change && to || err = errors.New("Must specify the file type and number of parameters can not exceed at least two") return } - c := args[0] //Content-Type + c := args[0] // Content-Type id := "" if len(args) > 1 { - id = args[1] //Content-ID + id = args[1] // Content-ID } var buffer bytes.Buffer if _, err = io.Copy(&buffer, r); err != nil { diff --git a/utils/mail_test.go b/core/utils/mail_test.go similarity index 100% rename from utils/mail_test.go rename to core/utils/mail_test.go diff --git a/core/utils/pagination/doc.go b/core/utils/pagination/doc.go new file mode 100644 index 0000000000..78beccd338 --- /dev/null +++ b/core/utils/pagination/doc.go @@ -0,0 +1,52 @@ +/* +Package pagination provides utilities to setup a paginator within the +context of a http request. + +# Usage + +In your beego.Controller: + + package controllers + + import "github.com/beego/beego/v2/core/utils/pagination" + + type PostsController struct { + beego.Controller + } + + func (this *PostsController) ListAllPosts() { + // sets this.Data["paginator"] with the current offset (from the url query param) + postsPerPage := 20 + paginator := pagination.SetPaginator(this.Ctx, postsPerPage, CountPosts()) + + // fetch the next 20 posts + this.Data["posts"] = ListPostsByOffsetAndLimit(paginator.Offset(), postsPerPage) + } + +In your view templates: + + {{if .paginator.HasPages}} + + {{end}} +*/ +package pagination diff --git a/utils/pagination/paginator.go b/core/utils/pagination/paginator.go similarity index 95% rename from utils/pagination/paginator.go rename to core/utils/pagination/paginator.go index c6db31e082..6a9ca7c415 100644 --- a/utils/pagination/paginator.go +++ b/core/utils/pagination/paginator.go @@ -78,11 +78,11 @@ func (p *Paginator) Page() int { // // Usage (in a view template): // -// {{range $index, $page := .paginator.Pages}} -// -// {{$page}} -// -// {{end}} +// {{range $index, $page := .paginator.Pages}} +// +// {{$page}} +// +// {{end}} func (p *Paginator) Pages() []int { if p.pageRange == nil && p.nums > 0 { var pages []int diff --git a/utils/pagination/utils.go b/core/utils/pagination/utils.go similarity index 100% rename from utils/pagination/utils.go rename to core/utils/pagination/utils.go diff --git a/utils/rand.go b/core/utils/rand.go similarity index 97% rename from utils/rand.go rename to core/utils/rand.go index 344d1cd534..f80dad7df8 100644 --- a/utils/rand.go +++ b/core/utils/rand.go @@ -27,7 +27,7 @@ func RandomCreateBytes(n int, alphabets ...byte) []byte { if len(alphabets) == 0 { alphabets = alphaNum } - var bytes = make([]byte, n) + bytes := make([]byte, n) var randBy bool if num, err := rand.Read(bytes); num != n || err != nil { r.Seed(time.Now().UnixNano()) diff --git a/utils/rand_test.go b/core/utils/rand_test.go similarity index 100% rename from utils/rand_test.go rename to core/utils/rand_test.go diff --git a/utils/safemap.go b/core/utils/safemap.go similarity index 98% rename from utils/safemap.go rename to core/utils/safemap.go index 1793030a5f..8c48cb5b68 100644 --- a/utils/safemap.go +++ b/core/utils/safemap.go @@ -18,7 +18,7 @@ import ( "sync" ) -// BeeMap is a map with lock +// Deprecated: using sync.Map type BeeMap struct { lock *sync.RWMutex bm map[interface{}]interface{} diff --git a/utils/safemap_test.go b/core/utils/safemap_test.go similarity index 100% rename from utils/safemap_test.go rename to core/utils/safemap_test.go diff --git a/utils/slice.go b/core/utils/slice.go similarity index 99% rename from utils/slice.go rename to core/utils/slice.go index 8f2cef980f..0eda467036 100644 --- a/utils/slice.go +++ b/core/utils/slice.go @@ -20,6 +20,7 @@ import ( ) type reducetype func(interface{}) interface{} + type filtertype func(interface{}) bool // InSlice checks given string in string slice or not. diff --git a/utils/slice_test.go b/core/utils/slice_test.go similarity index 100% rename from utils/slice_test.go rename to core/utils/slice_test.go diff --git a/utils/testdata/grepe.test b/core/utils/testdata/grepe.test similarity index 100% rename from utils/testdata/grepe.test rename to core/utils/testdata/grepe.test diff --git a/core/utils/time.go b/core/utils/time.go new file mode 100644 index 0000000000..2f813a0575 --- /dev/null +++ b/core/utils/time.go @@ -0,0 +1,46 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "fmt" + "time" +) + +// ToShortTimeFormat short string format +func ToShortTimeFormat(d time.Duration) string { + u := uint64(d) + if u < uint64(time.Second) { + switch { + case u == 0: + return "0" + case u < uint64(time.Microsecond): + return fmt.Sprintf("%.2fns", float64(u)) + case u < uint64(time.Millisecond): + return fmt.Sprintf("%.2fus", float64(u)/1000) + default: + return fmt.Sprintf("%.2fms", float64(u)/1000/1000) + } + } else { + switch { + case u < uint64(time.Minute): + return fmt.Sprintf("%.2fs", float64(u)/1000/1000/1000) + case u < uint64(time.Hour): + return fmt.Sprintf("%.2fm", float64(u)/1000/1000/1000/60) + default: + return fmt.Sprintf("%.2fh", float64(u)/1000/1000/1000/60/60) + } + } +} diff --git a/core/utils/utils.go b/core/utils/utils.go new file mode 100644 index 0000000000..3874b803b1 --- /dev/null +++ b/core/utils/utils.go @@ -0,0 +1,89 @@ +package utils + +import ( + "os" + "path/filepath" + "regexp" + "runtime" + "strconv" + "strings" +) + +// GetGOPATHs returns all paths in GOPATH variable. +func GetGOPATHs() []string { + gopath := os.Getenv("GOPATH") + if gopath == "" && compareGoVersion(runtime.Version(), "go1.8") >= 0 { + gopath = defaultGOPATH() + } + return filepath.SplitList(gopath) +} + +func compareGoVersion(a, b string) int { + reg := regexp.MustCompile("^\\d*") + + a = strings.TrimPrefix(a, "go") + b = strings.TrimPrefix(b, "go") + + versionsA := strings.Split(a, ".") + versionsB := strings.Split(b, ".") + + for i := 0; i < len(versionsA) && i < len(versionsB); i++ { + versionA := versionsA[i] + versionB := versionsB[i] + + vA, err := strconv.Atoi(versionA) + if err != nil { + str := reg.FindString(versionA) + if str != "" { + vA, _ = strconv.Atoi(str) + } else { + vA = -1 + } + } + + vB, err := strconv.Atoi(versionB) + if err != nil { + str := reg.FindString(versionB) + if str != "" { + vB, _ = strconv.Atoi(str) + } else { + vB = -1 + } + } + + if vA > vB { + // vA = 12, vB = 8 + return 1 + } else if vA < vB { + // vA = 6, vB = 8 + return -1 + } else if vA == -1 { + // vA = rc1, vB = rc3 + return strings.Compare(versionA, versionB) + } + + // vA = vB = 8 + continue + } + + if len(versionsA) > len(versionsB) { + return 1 + } else if len(versionsA) == len(versionsB) { + return 0 + } + + return -1 +} + +func defaultGOPATH() string { + env := "HOME" + if runtime.GOOS == "windows" { + env = "USERPROFILE" + } else if runtime.GOOS == "plan9" { + env = "home" + } + if home := os.Getenv(env); home != "" { + return filepath.Join(home, "go") + } + return "" +} diff --git a/core/utils/utils_test.go b/core/utils/utils_test.go new file mode 100644 index 0000000000..ced6f63fe2 --- /dev/null +++ b/core/utils/utils_test.go @@ -0,0 +1,36 @@ +package utils + +import ( + "testing" +) + +func TestCompareGoVersion(t *testing.T) { + targetVersion := "go1.8" + if compareGoVersion("go1.12.4", targetVersion) != 1 { + t.Error("should be 1") + } + + if compareGoVersion("go1.8.7", targetVersion) != 1 { + t.Error("should be 1") + } + + if compareGoVersion("go1.8", targetVersion) != 0 { + t.Error("should be 0") + } + + if compareGoVersion("go1.7.6", targetVersion) != -1 { + t.Error("should be -1") + } + + if compareGoVersion("go1.12.1rc1", targetVersion) != 1 { + t.Error("should be 1") + } + + if compareGoVersion("go1.8rc1", targetVersion) != 0 { + t.Error("should be 0") + } + + if compareGoVersion("go1.7rc1", targetVersion) != -1 { + t.Error("should be -1") + } +} diff --git a/validation/README.md b/core/validation/README.md similarity index 90% rename from validation/README.md rename to core/validation/README.md index 43373e47d7..0f9c167762 100644 --- a/validation/README.md +++ b/core/validation/README.md @@ -7,18 +7,18 @@ validation is a form validation for a data validation and error collecting using Install: - go get github.com/astaxie/beego/validation + go get github.com/beego/beego/v2/core/validation Test: - go test github.com/astaxie/beego/validation + go test github.com/beego/beego/v2/core/validation ## Example Direct Use: import ( - "github.com/astaxie/beego/validation" + "github.com/beego/beego/v2/core/validation" "log" ) @@ -49,7 +49,7 @@ Direct Use: Struct Tag Use: import ( - "github.com/astaxie/beego/validation" + "github.com/beego/beego/v2/core/validation" ) // validation function follow with "valid" tag @@ -81,7 +81,7 @@ Struct Tag Use: Use custom function: import ( - "github.com/astaxie/beego/validation" + "github.com/beego/beego/v2/core/validation" ) type user struct { @@ -140,7 +140,7 @@ Struct Tag Functions: Tel Phone ZipCode - + Enum ## LICENSE diff --git a/validation/util.go b/core/validation/util.go similarity index 92% rename from validation/util.go rename to core/validation/util.go index 66fce28372..69a4324369 100644 --- a/validation/util.go +++ b/core/validation/util.go @@ -26,6 +26,8 @@ const ( // ValidTag struct tag ValidTag = "valid" + LabelTag = "label" + wordsize = 32 << (^uint(0) >> 32 & 1) ) @@ -65,13 +67,15 @@ type CustomFunc func(v *Validation, obj interface{}, key string) // AddCustomFunc Add a custom function to validation // The name can not be: -// Clear -// HasErrors -// ErrorMap -// Error -// Check -// Valid -// NoMatch +// +// Clear +// HasErrors +// ErrorMap +// Error +// Check +// Valid +// NoMatch +// // If the name is same with exists function, it will replace the origin valid function func AddCustomFunc(name string, f CustomFunc) error { if unFuncs[name] { @@ -124,6 +128,7 @@ func isStructPtr(t reflect.Type) bool { func getValidFuncs(f reflect.StructField) (vfs []ValidFunc, err error) { tag := f.Tag.Get(ValidTag) + label := f.Tag.Get(LabelTag) if len(tag) == 0 { return } @@ -136,7 +141,7 @@ func getValidFuncs(f reflect.StructField) (vfs []ValidFunc, err error) { if len(vfunc) == 0 { continue } - vf, err = parseFunc(vfunc, f.Name) + vf, err = parseFunc(vfunc, f.Name, label) if err != nil { return } @@ -168,7 +173,7 @@ func getRegFuncs(tag, key string) (vfs []ValidFunc, str string, err error) { return } -func parseFunc(vfunc, key string) (v ValidFunc, err error) { +func parseFunc(vfunc, key string, label string) (v ValidFunc, err error) { defer func() { if r := recover(); r != nil { err = fmt.Errorf("%v", r) @@ -188,7 +193,7 @@ func parseFunc(vfunc, key string) (v ValidFunc, err error) { err = fmt.Errorf("%s require %d parameters", vfunc, num) return } - v = ValidFunc{vfunc, []interface{}{key + "." + vfunc}} + v = ValidFunc{vfunc, []interface{}{key + "." + vfunc + "." + label}} return } @@ -210,7 +215,7 @@ func parseFunc(vfunc, key string) (v ValidFunc, err error) { return } - tParams, err := trim(name, key+"."+name, params) + tParams, err := trim(name, key+"."+name+"."+label, params) if err != nil { return } @@ -221,7 +226,7 @@ func parseFunc(vfunc, key string) (v ValidFunc, err error) { func numIn(name string) (num int, err error) { fn, ok := funcs[name] if !ok { - err = fmt.Errorf("doesn't exsits %s valid function", name) + err = fmt.Errorf("doesn't exists %s valid function", name) return } // sub *Validation obj and key @@ -233,7 +238,7 @@ func trim(name, key string, s []string) (ts []interface{}, err error) { ts = make([]interface{}, len(s), len(s)+1) fn, ok := funcs[name] if !ok { - err = fmt.Errorf("doesn't exsits %s valid function", name) + err = fmt.Errorf("doesn't exists %s valid function", name) return } for i := 0; i < len(s); i++ { diff --git a/validation/util_test.go b/core/validation/util_test.go similarity index 66% rename from validation/util_test.go rename to core/validation/util_test.go index e74d50edfa..d9227082a6 100644 --- a/validation/util_test.go +++ b/core/validation/util_test.go @@ -15,6 +15,8 @@ package validation import ( + "fmt" + "log" "reflect" "testing" ) @@ -23,7 +25,7 @@ type user struct { ID int Tag string `valid:"Maxx(aa)"` Name string `valid:"Required;"` - Age int `valid:"Required;Range(1, 140)"` + Age int `valid:"Required; Range(1, 140)"` match string `valid:"Required; Match(/^(test)?\\w*@(/test/);com$/);Max(2)"` } @@ -42,7 +44,7 @@ func TestGetValidFuncs(t *testing.T) { } f, _ = tf.FieldByName("Tag") - if _, err = getValidFuncs(f); err.Error() != "doesn't exsits Maxx valid function" { + if _, err = getValidFuncs(f); err.Error() != "doesn't exists Maxx valid function" { t.Fatal(err) } @@ -80,6 +82,33 @@ func TestGetValidFuncs(t *testing.T) { } } +type User struct { + Name string `valid:"Required;MaxSize(5)" ` + Sex string `valid:"Required;" label:"sex_label"` + Age int `valid:"Required;Range(1, 140);" label:"age_label"` +} + +func TestValidation(t *testing.T) { + u := User{"man1238888456", "", 1140} + valid := Validation{} + b, err := valid.Valid(&u) + if err != nil { + // handle error + } + if !b { + // validation does not pass + // blabla... + for _, err := range valid.Errors { + log.Println(err.Key, err.Message) + } + if len(valid.Errors) != 3 { + t.Error("must be has 3 error") + } + } else { + t.Error("must be has 3 error") + } +} + func TestCall(t *testing.T) { u := user{Name: "test", Age: 180} tf := reflect.TypeOf(u) @@ -98,3 +127,29 @@ func TestCall(t *testing.T) { t.Error("age out of range should be has an error") } } + +func ExampleAddCustomFunc() { + err := AddCustomFunc("MyFunc", func(v *Validation, obj interface{}, key string) { + // do validation, and if you find something wrong, + // call AddError + v.AddError(key, "this is my error") + }) + if err != nil { + panic(err) + } + type MyUser struct { + Name string `valid:"MyFunc"` + } + v := Validation{} + ok, err := v.Valid(&MyUser{}) + if err != nil { + panic(err) + } + if !ok { + // get the validation error here + errs := v.Errors + fmt.Println(errs[0].Error()) + } + // Output: + // Name this is my error +} diff --git a/validation/validation.go b/core/validation/validation.go similarity index 91% rename from validation/validation.go rename to core/validation/validation.go index ca1e211fcc..b206520a09 100644 --- a/validation/validation.go +++ b/core/validation/validation.go @@ -15,7 +15,7 @@ // Package validation for validations // // import ( -// "github.com/astaxie/beego/validation" +// "github.com/beego/beego/v2/core/validation" // "log" // ) // @@ -42,8 +42,6 @@ // log.Println(v.Error.Key, v.Error.Message) // } // } -// -// more info: http://beego.me/docs/mvc/controller/validation.md package validation import ( @@ -60,9 +58,9 @@ type ValidFormer interface { // Error show the error type Error struct { - Message, Key, Name, Field, Tmpl string - Value interface{} - LimitValue interface{} + Message, Key, Name, Field, Tmpl, Label string + Value interface{} + LimitValue interface{} } // String Returns the Message. @@ -107,7 +105,7 @@ func (r *Result) Message(message string, args ...interface{}) *Result { // A Validation context manages data validation and error messages. type Validation struct { // if this field set true, in struct tag valid - // if the struct field vale is empty + // if the struct field value is empty // it will skip those valid functions, see CanSkipFuncs RequiredFirst bool @@ -121,7 +119,7 @@ func (v *Validation) Clear() { v.ErrorsMap = nil } -// HasErrors Has ValidationError nor not. +// HasErrors Has ValidationError or not. func (v *Validation) HasErrors() bool { return len(v.Errors) > 0 } @@ -158,7 +156,7 @@ func (v *Validation) Max(obj interface{}, max int, key string) *Result { return v.apply(Max{max, key}, obj) } -// Range Test that the obj is between mni and max if obj's type is int +// Range Test that the obj is between min and max if obj's type is int func (v *Validation) Range(obj interface{}, min, max int, key string) *Result { return v.apply(Range{Min{Min: min}, Max{Max: max}, key}, obj) } @@ -235,8 +233,11 @@ func (v *Validation) Tel(obj interface{}, key string) *Result { // Phone Test that the obj is chinese mobile or telephone number if type is string func (v *Validation) Phone(obj interface{}, key string) *Result { - return v.apply(Phone{Mobile{Match: Match{Regexp: mobilePattern}}, - Tel{Match: Match{Regexp: telPattern}}, key}, obj) + return v.apply(Phone{ + Mobile{Match: Match{Regexp: mobilePattern}}, + Tel{Match: Match{Regexp: telPattern}}, + key, + }, obj) } // ZipCode Test that the obj is chinese zip code if type is string @@ -244,6 +245,11 @@ func (v *Validation) ZipCode(obj interface{}, key string) *Result { return v.apply(ZipCode{Match{Regexp: zipCodePattern}, key}, obj) } +// Enum Test that the obj is in the specified enumeration if type is string +func (v *Validation) Enum(obj interface{}, vals string, key string) *Result { + return v.apply(Enum{vals, key}, obj) +} + func (v *Validation) apply(chk Validator, obj interface{}) *Result { if nil == obj { if chk.IsSatisfied(obj) { @@ -267,21 +273,31 @@ func (v *Validation) apply(chk Validator, obj interface{}) *Result { key := chk.GetKey() Name := key Field := "" - + Label := "" parts := strings.Split(key, ".") if len(parts) == 2 { Field = parts[0] Name = parts[1] + Label = Field + } + if len(parts) == 3 { + Field = parts[0] + Name = parts[1] + Label = parts[2] + if len(Label) == 0 { + Label = Field + } } err := &Error{ - Message: chk.DefaultMessage(), + Message: Label + " " + chk.DefaultMessage(), Key: key, Name: Name, Field: Field, Value: obj, Tmpl: MessageTmpls[Name], LimitValue: chk.GetLimitValue(), + Label: Label, } v.setError(err) @@ -292,22 +308,29 @@ func (v *Validation) apply(chk Validator, obj interface{}) *Result { } } +// AddError key must like aa.bb.cc or aa.bb. // AddError adds independent error message for the provided key func (v *Validation) AddError(key, message string) { Name := key Field := "" + Label := "" parts := strings.Split(key, ".") - if len(parts) == 2 { + if len(parts) == 3 { Field = parts[0] Name = parts[1] + Label = parts[2] + if len(Label) == 0 { + Label = Field + } } err := &Error{ - Message: message, + Message: Label + " " + message, Key: key, Name: Name, Field: Field, + Label: Label, } v.setError(err) } @@ -380,7 +403,6 @@ func (v *Validation) Valid(obj interface{}) (b bool, err error) { } } - chk := Required{""}.IsSatisfied(currentField) if !hasRequired && v.RequiredFirst && !chk { if _, ok := CanSkipFuncs[vf.Name]; ok { @@ -409,7 +431,7 @@ func (v *Validation) Valid(obj interface{}) (b bool, err error) { // Step2: If pass on step1, then reflect obj's fields // Step3: Do the Recursively validation to all struct or struct pointer fields func (v *Validation) RecursiveValid(objc interface{}) (bool, error) { - //Step 1: validate obj itself firstly + // Step 1: validate obj itself firstly // fails if objc is not struct pass, err := v.Valid(objc) if err != nil || !pass { diff --git a/validation/validation_test.go b/core/validation/validation_test.go similarity index 86% rename from validation/validation_test.go rename to core/validation/validation_test.go index 3146766bed..eaa82bd2a7 100644 --- a/validation/validation_test.go +++ b/core/validation/validation_test.go @@ -253,32 +253,68 @@ func TestBase64(t *testing.T) { func TestMobile(t *testing.T) { valid := Validation{} - if valid.Mobile("19800008888", "mobile").Ok { - t.Error("\"19800008888\" is a valid mobile phone number should be false") - } - if !valid.Mobile("18800008888", "mobile").Ok { - t.Error("\"18800008888\" is a valid mobile phone number should be true") - } - if !valid.Mobile("18000008888", "mobile").Ok { - t.Error("\"18000008888\" is a valid mobile phone number should be true") - } - if !valid.Mobile("8618300008888", "mobile").Ok { - t.Error("\"8618300008888\" is a valid mobile phone number should be true") - } - if !valid.Mobile("+8614700008888", "mobile").Ok { - t.Error("\"+8614700008888\" is a valid mobile phone number should be true") - } - if !valid.Mobile("17300008888", "mobile").Ok { - t.Error("\"17300008888\" is a valid mobile phone number should be true") - } - if !valid.Mobile("+8617100008888", "mobile").Ok { - t.Error("\"+8617100008888\" is a valid mobile phone number should be true") - } - if !valid.Mobile("8617500008888", "mobile").Ok { - t.Error("\"8617500008888\" is a valid mobile phone number should be true") - } - if valid.Mobile("8617400008888", "mobile").Ok { - t.Error("\"8617400008888\" is a valid mobile phone number should be false") + validMobiles := []string{ + "19800008888", + "18800008888", + "18000008888", + "8618300008888", + "+8614700008888", + "17300008888", + "+8617100008888", + "8617500008888", + "8617400008888", + "16200008888", + "16500008888", + "16600008888", + "16700008888", + "13300008888", + "14900008888", + "15300008888", + "17300008888", + "17700008888", + "18000008888", + "18900008888", + "19100008888", + "19900008888", + "19300008888", + "13000008888", + "13100008888", + "13200008888", + "14500008888", + "15500008888", + "15600008888", + "16600008888", + "17100008888", + "17500008888", + "17600008888", + "18500008888", + "18600008888", + "13400008888", + "13500008888", + "13600008888", + "13700008888", + "13800008888", + "13900008888", + "14700008888", + "15000008888", + "15100008888", + "15200008888", + "15800008888", + "15900008888", + "17200008888", + "17800008888", + "18200008888", + "18300008888", + "18400008888", + "18700008888", + "18800008888", + "19800008888", + } + + for _, m := range validMobiles { + if !valid.Mobile(m, "mobile").Ok { + t.Error(m + " is a valid mobile phone number should be true") + } } } @@ -324,6 +360,22 @@ func TestZipCode(t *testing.T) { } } +func TestEnum(t *testing.T) { + valid := Validation{} + + if valid.Enum("sms_code", "sms|email|code", "enum").Ok { + t.Error("\"sms_code\" is in the enum list of \"sms|email|code\" should be false") + } + + if !valid.Enum("sms", "email|sms|code", "enum").Ok { + t.Error("\"sms\" is in the enum list of \"email|sms|code\" should be true") + } + + if valid.Enum(200, "code|email|sms", "enum").Ok { + t.Error("200 is in the enum list of \"code|email|sms\" should be false") + } +} + func TestValid(t *testing.T) { type user struct { ID int @@ -369,8 +421,8 @@ func TestValid(t *testing.T) { if len(valid.Errors) != 1 { t.Fatalf("valid errors len should be 1 but got %d", len(valid.Errors)) } - if valid.Errors[0].Key != "Age.Range" { - t.Errorf("Message key should be `Name.Match` but got %s", valid.Errors[0].Key) + if valid.Errors[0].Key != "Age.Range." { + t.Errorf("Message key should be `Age.Range` but got %s", valid.Errors[0].Key) } } @@ -569,5 +621,29 @@ func TestCanSkipAlso(t *testing.T) { if !b { t.Fatal("validation should be passed") } +} +func TestFieldNoEmpty(t *testing.T) { + type User struct { + Name string `json:"name" valid:"Match(/^[a-zA-Z][a-zA-Z0-9._-]{0,31}$/)"` + } + u := User{ + Name: "*", + } + + valid := Validation{} + b, err := valid.Valid(u) + if err != nil { + t.Fatal(err) + } + if b { + t.Fatal("validation should be passed") + } + if len(valid.Errors) == 0 { + t.Fatal("validation should be passed") + } + validErr := valid.Errors[0] + if len(validErr.Field) == 0 { + t.Fatal("validation should be passed") + } } diff --git a/validation/validators.go b/core/validation/validators.go similarity index 88% rename from validation/validators.go rename to core/validation/validators.go index dc18b11eb1..366b9eecdb 100644 --- a/validation/validators.go +++ b/core/validation/validators.go @@ -19,8 +19,11 @@ import ( "reflect" "regexp" "strings" + "sync" "time" "unicode/utf8" + + "github.com/beego/beego/v2/core/logs" ) // CanSkipFuncs will skip valid if RequiredFirst is true and the struct field's value is empty @@ -55,38 +58,45 @@ var MessageTmpls = map[string]string{ "Tel": "Must be valid telephone number", "Phone": "Must be valid telephone or mobile phone number", "ZipCode": "Must be valid zipcode", + "Enum": "Must be a string value in \"%s\"", } +var once sync.Once + // SetDefaultMessage set default messages // if not set, the default messages are -// "Required": "Can not be empty", -// "Min": "Minimum is %d", -// "Max": "Maximum is %d", -// "Range": "Range is %d to %d", -// "MinSize": "Minimum size is %d", -// "MaxSize": "Maximum size is %d", -// "Length": "Required length is %d", -// "Alpha": "Must be valid alpha characters", -// "Numeric": "Must be valid numeric characters", -// "AlphaNumeric": "Must be valid alpha or numeric characters", -// "Match": "Must match %s", -// "NoMatch": "Must not match %s", -// "AlphaDash": "Must be valid alpha or numeric or dash(-_) characters", -// "Email": "Must be a valid email address", -// "IP": "Must be a valid ip address", -// "Base64": "Must be valid base64 characters", -// "Mobile": "Must be valid mobile number", -// "Tel": "Must be valid telephone number", -// "Phone": "Must be valid telephone or mobile phone number", -// "ZipCode": "Must be valid zipcode", +// +// "Required": "Can not be empty", +// "Min": "Minimum is %d", +// "Max": "Maximum is %d", +// "Range": "Range is %d to %d", +// "MinSize": "Minimum size is %d", +// "MaxSize": "Maximum size is %d", +// "Length": "Required length is %d", +// "Alpha": "Must be valid alpha characters", +// "Numeric": "Must be valid numeric characters", +// "AlphaNumeric": "Must be valid alpha or numeric characters", +// "Match": "Must match %s", +// "NoMatch": "Must not match %s", +// "AlphaDash": "Must be valid alpha or numeric or dash(-_) characters", +// "Email": "Must be a valid email address", +// "IP": "Must be a valid ip address", +// "Base64": "Must be valid base64 characters", +// "Mobile": "Must be valid mobile number", +// "Tel": "Must be valid telephone number", +// "Phone": "Must be valid telephone or mobile phone number", +// "ZipCode": "Must be valid zipcode", func SetDefaultMessage(msg map[string]string) { if len(msg) == 0 { return } - for name := range msg { - MessageTmpls[name] = msg[name] - } + once.Do(func() { + for name := range msg { + MessageTmpls[name] = msg[name] + } + }) + logs.Warn(`you must SetDefaultMessage at once`) } // Validator interface @@ -632,7 +642,7 @@ func (b Base64) GetLimitValue() interface{} { } // just for chinese mobile phone number -var mobilePattern = regexp.MustCompile(`^((\+86)|(86))?(1(([35][0-9])|[8][0-9]|[7][01356789]|[4][579]))\d{8}$`) +var mobilePattern = regexp.MustCompile(`^((\+86)|(86))?1([356789][0-9]|4[579]|6[67]|7[0135678]|9[189])[0-9]{8}$`) // Mobile check struct type Mobile struct { @@ -729,3 +739,37 @@ func (z ZipCode) GetKey() string { func (z ZipCode) GetLimitValue() interface{} { return nil } + +// Enum Requires that the field must be within the enumerated value +type Enum struct { + Rules string + Key string +} + +// IsSatisfied judge whether obj is valid +func (e Enum) IsSatisfied(i interface{}) bool { + if val, ok := i.(string); ok { + roles := strings.Split(e.Rules, "|") + for _, v := range roles { + if val == strings.TrimSpace(v) { + return true + } + } + } + return false +} + +// DefaultMessage return the default Enum error message +func (e Enum) DefaultMessage() string { + return fmt.Sprintf(MessageTmpls["Enum"], e.Rules) +} + +// GetKey return the e.Key +func (e Enum) GetKey() string { + return e.Key +} + +// GetLimitValue return nil now +func (Enum) GetLimitValue() interface{} { + return nil +} diff --git a/doc.go b/doc.go index 8825bd299e..6975885ab0 100644 --- a/doc.go +++ b/doc.go @@ -1,17 +1,15 @@ -/* -Package beego provide a MVC framework -beego: an open-source, high-performance, modular, full-stack web framework +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. -It is used for rapid development of RESTful APIs, web apps and backend services in Go. -beego is inspired by Tornado, Sinatra and Flask with the added benefit of some Go-specific features such as interfaces and struct embedding. - - package main - import "github.com/astaxie/beego" - - func main() { - beego.Run() - } - -more information: http://beego.me -*/ package beego diff --git a/filter.go b/filter.go deleted file mode 100644 index 9cc6e9134f..0000000000 --- a/filter.go +++ /dev/null @@ -1,44 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package beego - -import "github.com/astaxie/beego/context" - -// FilterFunc defines a filter function which is invoked before the controller handler is executed. -type FilterFunc func(*context.Context) - -// FilterRouter defines a filter operation which is invoked before the controller handler is executed. -// It can match the URL against a pattern, and execute a filter function -// when a request with a matching URL arrives. -type FilterRouter struct { - filterFunc FilterFunc - tree *Tree - pattern string - returnOnOutput bool - resetParams bool -} - -// ValidRouter checks if the current request is matched by this filter. -// If the request is matched, the values of the URL parameters defined -// by the filter pattern are also returned. -func (f *FilterRouter) ValidRouter(url string, ctx *context.Context) bool { - isOk := f.tree.Match(url, ctx) - if isOk != nil { - if b, ok := isOk.(bool); ok { - return b - } - } - return false -} diff --git a/go.mod b/go.mod index fbdec124da..d80bec1112 100644 --- a/go.mod +++ b/go.mod @@ -1,39 +1,87 @@ -module github.com/astaxie/beego +module github.com/beego/beego/v2 + +go 1.20 require ( - github.com/Knetic/govaluate v3.0.0+incompatible // indirect - github.com/OwnLocal/goes v1.0.0 - github.com/beego/goyaml2 v0.0.0-20130207012346-5545475820dd + github.com/DATA-DOG/go-sqlmock v1.5.1 github.com/beego/x2j v0.0.0-20131220205130-a0352aadc542 - github.com/bradfitz/gomemcache v0.0.0-20180710155616-bc664df96737 - github.com/casbin/casbin v1.7.0 + github.com/bits-and-blooms/bloom/v3 v3.6.0 + github.com/bradfitz/gomemcache v0.0.0-20190913173617-a41fca850d0b + github.com/casbin/casbin v1.9.1 github.com/cloudflare/golz4 v0.0.0-20150217214814-ef862a3cdc58 - github.com/couchbase/go-couchbase v0.0.0-20181122212707-3e9b6e1258bb - github.com/couchbase/gomemcached v0.0.0-20181122193126-5125a94a666c // indirect - github.com/couchbase/goutils v0.0.0-20180530154633-e865a1461c8a // indirect - github.com/cupcake/rdb v0.0.0-20161107195141-43ba34106c76 // indirect - github.com/edsrzf/mmap-go v0.0.0-20170320065105-0bce6a688712 // indirect - github.com/elazarl/go-bindata-assetfs v1.0.0 - github.com/go-redis/redis v6.14.2+incompatible - github.com/go-sql-driver/mysql v1.4.1 - github.com/gogo/protobuf v1.1.1 - github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db // indirect + github.com/couchbase/go-couchbase v0.1.0 + github.com/elastic/go-elasticsearch/v6 v6.8.10 + github.com/elazarl/go-bindata-assetfs v1.0.1 + github.com/go-kit/kit v0.12.1-0.20220826005032-a7ba4fa4e289 + github.com/go-kit/log v0.2.1 + github.com/go-sql-driver/mysql v1.8.1 + github.com/gogo/protobuf v1.3.2 github.com/gomodule/redigo v2.0.0+incompatible - github.com/lib/pq v1.0.0 - github.com/mattn/go-sqlite3 v1.10.0 - github.com/pelletier/go-toml v1.2.0 // indirect - github.com/pkg/errors v0.8.0 // indirect - github.com/siddontang/go v0.0.0-20180604090527-bdc77568d726 // indirect - github.com/siddontang/ledisdb v0.0.0-20181029004158-becf5f38d373 - github.com/siddontang/rdb v0.0.0-20150307021120-fc89ed2e418d // indirect + github.com/google/uuid v1.6.0 + github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0 + github.com/hashicorp/golang-lru v0.5.4 + github.com/ledisdb/ledisdb v0.0.0-20200510135210-d35789ec47e6 + github.com/lib/pq v1.10.5 + github.com/mattn/go-sqlite3 v1.14.22 + github.com/mitchellh/mapstructure v1.5.0 + github.com/opentracing/opentracing-go v1.2.0 + github.com/pelletier/go-toml v1.9.2 + github.com/prometheus/client_golang v1.19.0 + github.com/redis/go-redis/v9 v9.5.1 + github.com/shiena/ansicolor v0.0.0-20200904210342-c7312218db18 github.com/ssdb/gossdb v0.0.0-20180723034631-88f6b59b84ec - github.com/syndtr/goleveldb v0.0.0-20181127023241-353a9fca669c // indirect - github.com/wendal/errors v0.0.0-20130201093226-f66c77a7882b // indirect - golang.org/x/crypto v0.0.0-20181127143415-eb0de9b17e85 - golang.org/x/net v0.0.0-20181114220301-adae6a3d119a // indirect - gopkg.in/yaml.v2 v2.2.1 + github.com/stretchr/testify v1.9.0 + github.com/valyala/bytebufferpool v1.0.0 + go.etcd.io/etcd/client/v3 v3.5.9 + go.opentelemetry.io/otel v1.11.2 + go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.11.2 + go.opentelemetry.io/otel/sdk v1.11.2 + go.opentelemetry.io/otel/trace v1.11.2 + golang.org/x/crypto v0.24.0 + golang.org/x/sync v0.7.0 + google.golang.org/grpc v1.63.0 + google.golang.org/protobuf v1.34.2 + gopkg.in/yaml.v3 v3.0.1 ) -replace golang.org/x/crypto v0.0.0-20181127143415-eb0de9b17e85 => github.com/golang/crypto v0.0.0-20181127143415-eb0de9b17e85 +require ( + filippo.io/edwards25519 v1.1.0 // indirect + github.com/Knetic/govaluate v3.0.1-0.20171022003610-9aa49832a739+incompatible // indirect + github.com/beorn7/perks v1.0.1 // indirect + github.com/bits-and-blooms/bitset v1.10.0 // indirect + github.com/cespare/xxhash/v2 v2.2.0 // indirect + github.com/coreos/go-semver v0.3.0 // indirect + github.com/coreos/go-systemd/v22 v22.3.2 // indirect + github.com/couchbase/gomemcached v0.1.3 // indirect + github.com/couchbase/goutils v0.1.0 // indirect + github.com/cupcake/rdb v0.0.0-20161107195141-43ba34106c76 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect + github.com/edsrzf/mmap-go v1.0.0 // indirect + github.com/go-logfmt/logfmt v0.5.1 // indirect + github.com/go-logr/logr v1.2.3 // indirect + github.com/go-logr/stdr v1.2.2 // indirect + github.com/golang/protobuf v1.5.4 // indirect + github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db // indirect + github.com/pkg/errors v0.9.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/prometheus/client_model v0.5.0 // indirect + github.com/prometheus/common v0.48.0 // indirect + github.com/prometheus/procfs v0.12.0 // indirect + github.com/siddontang/go v0.0.0-20170517070808-cb568a3e5cc0 // indirect + github.com/siddontang/rdb v0.0.0-20150307021120-fc89ed2e418d // indirect + github.com/syndtr/goleveldb v0.0.0-20160425020131-cfa635847112 // indirect + go.etcd.io/etcd/api/v3 v3.5.9 // indirect + go.etcd.io/etcd/client/pkg/v3 v3.5.9 // indirect + go.uber.org/atomic v1.9.0 // indirect + go.uber.org/multierr v1.7.0 // indirect + go.uber.org/zap v1.19.1 // indirect + golang.org/x/net v0.23.0 // indirect + golang.org/x/sys v0.21.0 // indirect + golang.org/x/text v0.16.0 // indirect + google.golang.org/genproto v0.0.0-20240227224415-6ceb2ff114de // indirect + google.golang.org/genproto/googleapis/api v0.0.0-20240227224415-6ceb2ff114de // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20240227224415-6ceb2ff114de // indirect +) -replace gopkg.in/yaml.v2 v2.2.1 => github.com/go-yaml/yaml v0.0.0-20180328195020-5420a8b6744d +replace github.com/gomodule/redigo => github.com/gomodule/redigo v1.8.8 diff --git a/go.sum b/go.sum index ab23316219..ce02dfe582 100644 --- a/go.sum +++ b/go.sum @@ -1,68 +1,257 @@ -github.com/Knetic/govaluate v3.0.0+incompatible h1:7o6+MAPhYTCF0+fdvoz1xDedhRb4f6s9Tn1Tt7/WTEg= -github.com/Knetic/govaluate v3.0.0+incompatible/go.mod h1:r7JcOSlj0wfOMncg0iLm8Leh48TZaKVeNIfJntJ2wa0= -github.com/OwnLocal/goes v1.0.0/go.mod h1:8rIFjBGTue3lCU0wplczcUgt9Gxgrkkrw7etMIcn8TM= -github.com/beego/goyaml2 v0.0.0-20130207012346-5545475820dd h1:jZtX5jh5IOMu0fpOTC3ayh6QGSPJ/KWOv1lgPvbRw1M= -github.com/beego/goyaml2 v0.0.0-20130207012346-5545475820dd/go.mod h1:1b+Y/CofkYwXMUU0OhQqGvsY2Bvgr4j6jfT699wyZKQ= +filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= +filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= +github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= +github.com/DATA-DOG/go-sqlmock v1.5.1 h1:FK6RCIUSfmbnI/imIICmboyQBkOckutaa6R5YYlLZyo= +github.com/DATA-DOG/go-sqlmock v1.5.1/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU= +github.com/Knetic/govaluate v3.0.1-0.20171022003610-9aa49832a739+incompatible h1:1G1pk05UrOh0NlF1oeaaix1x8XzrfjIDK47TY0Zehcw= +github.com/Knetic/govaluate v3.0.1-0.20171022003610-9aa49832a739+incompatible/go.mod h1:r7JcOSlj0wfOMncg0iLm8Leh48TZaKVeNIfJntJ2wa0= +github.com/alicebob/gopher-json v0.0.0-20180125190556-5a6b3ba71ee6/go.mod h1:SGnFV6hVsYE877CKEZ6tDNTjaSXYUk6QqoIK6PrAtcc= +github.com/alicebob/miniredis v2.5.0+incompatible/go.mod h1:8HZjEj4yU0dwhYHky+DxYx+6BMjkBbe5ONFIF1MXffk= github.com/beego/x2j v0.0.0-20131220205130-a0352aadc542 h1:nYXb+3jF6Oq/j8R/y90XrKpreCxIalBWfeyeKymgOPk= github.com/beego/x2j v0.0.0-20131220205130-a0352aadc542/go.mod h1:kSeGC/p1AbBiEp5kat81+DSQrZenVBZXklMLaELspWU= -github.com/belogik/goes v0.0.0-20151229125003-e54d722c3aff h1:/kO0p2RTGLB8R5gub7ps0GmYpB2O8LXEoPq8tzFDCUI= -github.com/belogik/goes v0.0.0-20151229125003-e54d722c3aff/go.mod h1:PhH1ZhyCzHKt4uAasyx+ljRCgoezetRNf59CUtwUkqY= -github.com/bradfitz/gomemcache v0.0.0-20180710155616-bc664df96737 h1:rRISKWyXfVxvoa702s91Zl5oREZTrR3yv+tXrrX7G/g= -github.com/bradfitz/gomemcache v0.0.0-20180710155616-bc664df96737/go.mod h1:PmM6Mmwb0LSuEubjR8N7PtNe1KxZLtOUHtbeikc5h60= -github.com/casbin/casbin v1.7.0 h1:PuzlE8w0JBg/DhIqnkF1Dewf3z+qmUZMVN07PonvVUQ= -github.com/casbin/casbin v1.7.0/go.mod h1:c67qKN6Oum3UF5Q1+BByfFxkwKvhwW57ITjqwtzR1KE= +github.com/benbjohnson/clock v1.1.0 h1:Q92kusRqC1XV2MjkWETPvjJVqKetz1OzxZB7mHJLju8= +github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA= +github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= +github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= +github.com/bits-and-blooms/bitset v1.10.0 h1:ePXTeiPEazB5+opbv5fr8umg2R/1NlzgDsyepwsSr88= +github.com/bits-and-blooms/bitset v1.10.0/go.mod h1:7hO7Gc7Pp1vODcmWvKMRA9BNmbv6a/7QIWpPxHddWR8= +github.com/bits-and-blooms/bloom/v3 v3.6.0 h1:dTU0OVLJSoOhz9m68FTXMFfA39nR8U/nTCs1zb26mOI= +github.com/bits-and-blooms/bloom/v3 v3.6.0/go.mod h1:VKlUSvp0lFIYqxJjzdnSsZEw4iHb1kOL2tfHTgyJBHg= +github.com/bradfitz/gomemcache v0.0.0-20190913173617-a41fca850d0b h1:L/QXpzIa3pOvUGt1D1lA5KjYhPBAN/3iWdP7xeFS9F0= +github.com/bradfitz/gomemcache v0.0.0-20190913173617-a41fca850d0b/go.mod h1:H0wQNHz2YrLsuXOZozoeDmnHXkNCRmMW0gwFWDfEZDA= +github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= +github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= +github.com/casbin/casbin v1.9.1 h1:ucjbS5zTrmSLtH4XogqOG920Poe6QatdXtz1FEbApeM= +github.com/casbin/casbin v1.9.1/go.mod h1:z8uPsfBJGUsnkagrt3G8QvjgTKFMBJ32UP8HpZllfog= +github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44= +github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/cloudflare/golz4 v0.0.0-20150217214814-ef862a3cdc58 h1:F1EaeKL/ta07PY/k9Os/UFtwERei2/XzGemhpGnBKNg= github.com/cloudflare/golz4 v0.0.0-20150217214814-ef862a3cdc58/go.mod h1:EOBUe0h4xcZ5GoxqC5SDxFQ8gwyZPKQoEzownBlhI80= -github.com/couchbase/go-couchbase v0.0.0-20181122212707-3e9b6e1258bb h1:w3RapLhkA5+km9Z8vUkC6VCaskduJXvXwJg5neKnfDU= -github.com/couchbase/go-couchbase v0.0.0-20181122212707-3e9b6e1258bb/go.mod h1:TWI8EKQMs5u5jLKW/tsb9VwauIrMIxQG1r5fMsswK5U= -github.com/couchbase/gomemcached v0.0.0-20181122193126-5125a94a666c h1:K4FIibkr4//ziZKOKmt4RL0YImuTjLLBtwElf+F2lSQ= -github.com/couchbase/gomemcached v0.0.0-20181122193126-5125a94a666c/go.mod h1:srVSlQLB8iXBVXHgnqemxUXqN6FCvClgCMPCsjBDR7c= -github.com/couchbase/goutils v0.0.0-20180530154633-e865a1461c8a h1:Y5XsLCEhtEI8qbD9RP3Qlv5FXdTDHxZM9UPUnMRgBp8= -github.com/couchbase/goutils v0.0.0-20180530154633-e865a1461c8a/go.mod h1:BQwMFlJzDjFDG3DJUdU0KORxn88UlsOULuxLExMh3Hs= +github.com/coreos/go-semver v0.3.0 h1:wkHLiw0WNATZnSG7epLsujiMCgPAc9xhjJ4tgnAxmfM= +github.com/coreos/go-semver v0.3.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk= +github.com/coreos/go-systemd/v22 v22.3.2 h1:D9/bQk5vlXQFZ6Kwuu6zaiXJ9oTPe68++AzAJc1DzSI= +github.com/coreos/go-systemd/v22 v22.3.2/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= +github.com/couchbase/go-couchbase v0.1.0 h1:g4bCvDwRL+ZL6HLhYeRlXxEYP31Wpy0VFxnFw6efEp8= +github.com/couchbase/go-couchbase v0.1.0/go.mod h1:+/bddYDxXsf9qt0xpDUtRR47A2GjaXmGGAqQ/k3GJ8A= +github.com/couchbase/gomemcached v0.1.3 h1:HIc5qMYNbuhB7zNaiEtj61DCYkquAwrQlf64q7JzdEY= +github.com/couchbase/gomemcached v0.1.3/go.mod h1:mxliKQxOv84gQ0bJWbI+w9Wxdpt9HjDvgW9MjCym5Vo= +github.com/couchbase/goutils v0.1.0 h1:0WLlKJilu7IBm98T8nS9+J36lBFVLRUSIUtyD/uWpAE= +github.com/couchbase/goutils v0.1.0/go.mod h1:BQwMFlJzDjFDG3DJUdU0KORxn88UlsOULuxLExMh3Hs= github.com/cupcake/rdb v0.0.0-20161107195141-43ba34106c76 h1:Lgdd/Qp96Qj8jqLpq2cI1I1X7BJnu06efS+XkhRoLUQ= github.com/cupcake/rdb v0.0.0-20161107195141-43ba34106c76/go.mod h1:vYwsqCOLxGiisLwp9rITslkFNpZD5rz43tf41QFkTWY= -github.com/edsrzf/mmap-go v0.0.0-20170320065105-0bce6a688712 h1:aaQcKT9WumO6JEJcRyTqFVq4XUZiUcKR2/GI31TOcz8= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= github.com/edsrzf/mmap-go v0.0.0-20170320065105-0bce6a688712/go.mod h1:YO35OhQPt3KJa3ryjFM5Bs14WD66h8eGKpfaBNrHW5M= -github.com/elazarl/go-bindata-assetfs v1.0.0 h1:G/bYguwHIzWq9ZoyUQqrjTmJbbYn3j3CKKpKinvZLFk= -github.com/elazarl/go-bindata-assetfs v1.0.0/go.mod h1:v+YaWX3bdea5J/mo8dSETolEo7R71Vk1u8bnjau5yw4= -github.com/go-redis/redis v6.14.2+incompatible h1:UE9pLhzmWf+xHNmZsoccjXosPicuiNaInPgym8nzfg0= -github.com/go-redis/redis v6.14.2+incompatible/go.mod h1:NAIEuMOZ/fxfXJIrKDQDz8wamY7mA7PouImQ2Jvg6kA= -github.com/go-sql-driver/mysql v1.4.1 h1:g24URVg0OFbNUTx9qqY1IRZ9D9z3iPyi5zKhQZpNwpA= -github.com/go-sql-driver/mysql v1.4.1/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w= -github.com/go-yaml/yaml v0.0.0-20180328195020-5420a8b6744d h1:xy93KVe+KrIIwWDEAfQBdIfsiHJkepbYsDr+VY3g9/o= -github.com/go-yaml/yaml v0.0.0-20180328195020-5420a8b6744d/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -github.com/gogo/protobuf v1.1.1 h1:72R+M5VuhED/KujmZVcIquuo8mBgX4oVda//DQb3PXo= -github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= -github.com/golang/crypto v0.0.0-20181127143415-eb0de9b17e85 h1:B7ZbAFz7NOmvpUE5RGtu3u0WIizy5GdvbNpEf4RPnWs= -github.com/golang/crypto v0.0.0-20181127143415-eb0de9b17e85/go.mod h1:uZvAcrsnNaCxlh1HorK5dUQHGmEKPh2H/Rl1kehswPo= +github.com/edsrzf/mmap-go v1.0.0 h1:CEBF7HpRnUCSJgGUb5h1Gm7e3VkmVDrR8lvWVLtrOFw= +github.com/edsrzf/mmap-go v1.0.0/go.mod h1:YO35OhQPt3KJa3ryjFM5Bs14WD66h8eGKpfaBNrHW5M= +github.com/elastic/go-elasticsearch/v6 v6.8.10 h1:2lN0gJ93gMBXvkhwih5xquldszpm8FlUwqG5sPzr6a8= +github.com/elastic/go-elasticsearch/v6 v6.8.10/go.mod h1:UwaDJsD3rWLM5rKNFzv9hgox93HoX8utj1kxD9aFUcI= +github.com/elazarl/go-bindata-assetfs v1.0.1 h1:m0kkaHRKEu7tUIUFVwhGGGYClXvyl4RE03qmvRTNfbw= +github.com/elazarl/go-bindata-assetfs v1.0.1/go.mod h1:v+YaWX3bdea5J/mo8dSETolEo7R71Vk1u8bnjau5yw4= +github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= +github.com/glendc/gopher-json v0.0.0-20170414221815-dc4743023d0c/go.mod h1:Gja1A+xZ9BoviGJNA2E9vFkPjjsl+CoJxSXiQM1UXtw= +github.com/go-kit/kit v0.12.1-0.20220826005032-a7ba4fa4e289 h1:468Nv6YtYO38Z+pFL6fRSVILGfdTFPei9ksVneiUHUc= +github.com/go-kit/kit v0.12.1-0.20220826005032-a7ba4fa4e289/go.mod h1:phqEHMMUbyrCFCTgH48JueqrM3md2HcAZ8N3XE4FKDg= +github.com/go-kit/log v0.2.1 h1:MRVx0/zhvdseW+Gza6N9rVzU/IVzaeE1SFI4raAhmBU= +github.com/go-kit/log v0.2.1/go.mod h1:NwTd00d/i8cPZ3xOwwiv2PO5MOcx78fFErGNcVmBjv0= +github.com/go-logfmt/logfmt v0.5.1 h1:otpy5pqBCBZ1ng9RQ0dPu4PN7ba75Y/aA+UpowDyNVA= +github.com/go-logfmt/logfmt v0.5.1/go.mod h1:WYhtIu8zTZfxdn5+rREduYbwxfcBr/Vr6KEVveWlfTs= +github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= +github.com/go-logr/logr v1.2.3 h1:2DntVwHkVopvECVRSlL5PSo9eG+cAkDCuckLubN+rq0= +github.com/go-logr/logr v1.2.3/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= +github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= +github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= +github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y= +github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg= +github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= +github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= +github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= +github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= +github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= +github.com/golang/snappy v0.0.0-20170215233205-553a64147049/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db h1:woRePGFeVFfLKN/pOkfl+p/TAqKOfFu+7KPlMVpok/w= github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= -github.com/gomodule/redigo v2.0.0+incompatible h1:K/R+8tc58AaqLkqG2Ol3Qk+DR/TlNuhuh457pBFPtt0= -github.com/gomodule/redigo v2.0.0+incompatible/go.mod h1:B4C85qUVwatsJoIUNIfCRsp7qO0iAmpGFZ4EELWSbC4= -github.com/lib/pq v1.0.0 h1:X5PMW56eZitiTeO7tKzZxFCSpbFZJtkMMooicw2us9A= -github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= -github.com/mattn/go-sqlite3 v1.10.0 h1:jbhqpg7tQe4SupckyijYiy0mJJ/pRyHvXf7JdWK860o= -github.com/mattn/go-sqlite3 v1.10.0/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc= -github.com/pelletier/go-toml v1.2.0 h1:T5zMGML61Wp+FlcbWjRDT7yAxhJNAiPPLOFECq181zc= -github.com/pelletier/go-toml v1.2.0/go.mod h1:5z9KED0ma1S8pY6P1sdut58dfprrGBbd/94hg7ilaic= -github.com/pkg/errors v0.8.0 h1:WdK/asTD0HN+q6hsWO3/vpuAkAr+tw6aNJNDFFf0+qw= -github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= -github.com/siddontang/go v0.0.0-20180604090527-bdc77568d726 h1:xT+JlYxNGqyT+XcU8iUrN18JYed2TvG9yN5ULG2jATM= -github.com/siddontang/go v0.0.0-20180604090527-bdc77568d726/go.mod h1:3yhqj7WBBfRhbBlzyOC3gUxftwsU0u8gqevxwIHQpMw= -github.com/siddontang/ledisdb v0.0.0-20181029004158-becf5f38d373 h1:p6IxqQMjab30l4lb9mmkIkkcE1yv6o0SKbPhW5pxqHI= -github.com/siddontang/ledisdb v0.0.0-20181029004158-becf5f38d373/go.mod h1:mF1DpOSOUiJRMR+FDqaqu3EBqrybQtrDDszLUZ6oxPg= +github.com/gomodule/redigo v1.8.8 h1:f6cXq6RRfiyrOJEV7p3JhLDlmawGBVBBP1MggY8Mo4E= +github.com/gomodule/redigo v1.8.8/go.mod h1:7ArFNvsTjH8GMMzB4uy1snslv2BwmginuMs06a1uzZE= +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0 h1:Ovs26xHkKqVztRpIrF/92BcuyuQ/YW4NSIpoGtfXNho= +github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0/go.mod h1:8NvIoxWQoOIhqOTXgfV/d3M/q6VIi02HzZEHgUlZvzk= +github.com/hashicorp/golang-lru v0.5.4 h1:YDjusn29QI/Das2iO9M0BHnIbxPeyuCHsjMW+lJfyTc= +github.com/hashicorp/golang-lru v0.5.4/go.mod h1:iADmTwqILo4mZ8BN3D2Q6+9jd8WM5uGBxy+E8yxSoD4= +github.com/hpcloud/tail v1.0.0 h1:nfCOvKYfkgYP8hkirhJocXT2+zOD8yUNjXaWfTlyFKI= +github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= +github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= +github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= +github.com/kisielk/sqlstruct v0.0.0-20201105191214-5f3e10d3ab46/go.mod h1:yyMNCyc/Ib3bDTKd379tNMpB/7/H5TjM2Y9QJ5THLbE= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/ledisdb/ledisdb v0.0.0-20200510135210-d35789ec47e6 h1:wxyqOzKxsRJ6vVRL9sXQ64Z45wmBuQ+OTH9sLsC5rKc= +github.com/ledisdb/ledisdb v0.0.0-20200510135210-d35789ec47e6/go.mod h1:n931TsDuKuq+uX4v1fulaMbA/7ZLLhjc85h7chZGBCQ= +github.com/lib/pq v1.10.5 h1:J+gdV2cUmX7ZqL2B0lFcW0m+egaHC2V3lpO8nWxyYiQ= +github.com/lib/pq v1.10.5/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= +github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU= +github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= +github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= +github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= +github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= +github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= +github.com/onsi/ginkgo v1.12.0 h1:Iw5WCbBcaAAd0fpRb1c9r5YCylv4XDoCSigm1zLevwU= +github.com/onsi/ginkgo v1.12.0/go.mod h1:oUhWkIvk5aDxtKvDDuw8gItl8pKl42LzjC9KZE0HfGg= +github.com/onsi/gomega v1.7.1 h1:K0jcRCwNQM3vFGh1ppMtDh/+7ApJrjldlX8fA0jDTLQ= +github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7JYyY= +github.com/opentracing/opentracing-go v1.2.0 h1:uEJPy/1a5RIPAJ0Ov+OIO8OxWu77jEv+1B0VhjKrZUs= +github.com/opentracing/opentracing-go v1.2.0/go.mod h1:GxEUsuufX4nBwe+T+Wl9TAgYrxe9dPLANfrWvHYVTgc= +github.com/pelletier/go-toml v1.0.1/go.mod h1:5z9KED0ma1S8pY6P1sdut58dfprrGBbd/94hg7ilaic= +github.com/pelletier/go-toml v1.9.2 h1:7NiByeVF4jKSG1lDF3X8LTIkq2/bu+1uYbIm1eS5tzk= +github.com/pelletier/go-toml v1.9.2/go.mod h1:u1nR/EPcESfeI/szUZKdtJ0xRNbUoANCkoOuaOx1Y+c= +github.com/peterh/liner v1.0.1-0.20171122030339-3681c2a91233/go.mod h1:xIteQHvHuaLYG9IFj6mSxM0fCKrs34IrEQUhOYuGPHc= +github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/prometheus/client_golang v1.19.0 h1:ygXvpU1AoN1MhdzckN+PyD9QJOSD4x7kmXYlnfbA6JU= +github.com/prometheus/client_golang v1.19.0/go.mod h1:ZRM9uEAypZakd+q/x7+gmsvXdURP+DABIEIjnmDdp+k= +github.com/prometheus/client_model v0.5.0 h1:VQw1hfvPvk3Uv6Qf29VrPF32JB6rtbgI6cYPYQjL0Qw= +github.com/prometheus/client_model v0.5.0/go.mod h1:dTiFglRmd66nLR9Pv9f0mZi7B7fk5Pm3gvsjB5tr+kI= +github.com/prometheus/common v0.48.0 h1:QO8U2CdOzSn1BBsmXJXduaaW+dY/5QLjfB8svtSzKKE= +github.com/prometheus/common v0.48.0/go.mod h1:0/KsvlIEfPQCQ5I2iNSAWKPZziNCvRs5EC6ILDTlAPc= +github.com/prometheus/procfs v0.12.0 h1:jluTpSng7V9hY0O2R9DzzJHYb2xULk9VTR1V1R/k6Bo= +github.com/prometheus/procfs v0.12.0/go.mod h1:pcuDEFsWDnvcgNzo4EEweacyhjeA9Zk3cnaOZAZEfOo= +github.com/redis/go-redis/v9 v9.5.1 h1:H1X4D3yHPaYrkL5X06Wh6xNVM/pX0Ft4RV0vMGvLBh8= +github.com/redis/go-redis/v9 v9.5.1/go.mod h1:hdY0cQFCN4fnSYT6TkisLufl/4W5UIXyv0b/CLO2V2M= +github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ= +github.com/shiena/ansicolor v0.0.0-20200904210342-c7312218db18 h1:DAYUYH5869yV94zvCES9F51oYtN5oGlwjxJJz7ZCnik= +github.com/shiena/ansicolor v0.0.0-20200904210342-c7312218db18/go.mod h1:nkxAfR/5quYxwPZhyDxgasBMnRtBZd0FCEpawpjMUFg= +github.com/siddontang/go v0.0.0-20170517070808-cb568a3e5cc0 h1:QIF48X1cihydXibm+4wfAc0r/qyPyuFiPFRNphdMpEE= +github.com/siddontang/go v0.0.0-20170517070808-cb568a3e5cc0/go.mod h1:3yhqj7WBBfRhbBlzyOC3gUxftwsU0u8gqevxwIHQpMw= +github.com/siddontang/goredis v0.0.0-20150324035039-760763f78400/go.mod h1:DDcKzU3qCuvj/tPnimWSsZZzvk9qvkvrIL5naVBPh5s= github.com/siddontang/rdb v0.0.0-20150307021120-fc89ed2e418d h1:NVwnfyR3rENtlz62bcrkXME3INVUa4lcdGt+opvxExs= github.com/siddontang/rdb v0.0.0-20150307021120-fc89ed2e418d/go.mod h1:AMEsy7v5z92TR1JKMkLLoaOQk++LVnOKL3ScbJ8GNGA= github.com/ssdb/gossdb v0.0.0-20180723034631-88f6b59b84ec h1:q6XVwXmKvCRHRqesF3cSv6lNqqHi0QWOvgDlSohg8UA= github.com/ssdb/gossdb v0.0.0-20180723034631-88f6b59b84ec/go.mod h1:QBvMkMya+gXctz3kmljlUCu/yB3GZ6oee+dUozsezQE= -github.com/syndtr/goleveldb v0.0.0-20181127023241-353a9fca669c h1:3eGShk3EQf5gJCYW+WzA0TEJQd37HLOmlYF7N0YJwv0= -github.com/syndtr/goleveldb v0.0.0-20181127023241-353a9fca669c/go.mod h1:Z4AUp2Km+PwemOoO/VB5AOx9XSsIItzFjoJlOSiYmn0= -github.com/wendal/errors v0.0.0-20130201093226-f66c77a7882b h1:0Ve0/CCjiAiyKddUMUn3RwIGlq2iTW4GuVzyoKBYO/8= -github.com/wendal/errors v0.0.0-20130201093226-f66c77a7882b/go.mod h1:Q12BUT7DqIlHRmgv3RskH+UCM/4eqVMgI0EMmlSpAXc= -golang.org/x/crypto v0.0.0-20181127143415-eb0de9b17e85 h1:et7+NAX3lLIk5qUCTA9QelBjGE/NkhzYw/mhnr0s7nI= -golang.org/x/crypto v0.0.0-20181127143415-eb0de9b17e85/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= -golang.org/x/net v0.0.0-20181114220301-adae6a3d119a h1:gOpx8G595UYyvj8UK4+OFyY4rx037g3fmfhe5SasG3U= -golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/syndtr/goleveldb v0.0.0-20160425020131-cfa635847112 h1:NBrpnvz0pDPf3+HXZ1C9GcJd1DTpWDLcLWZhNq6uP7o= +github.com/syndtr/goleveldb v0.0.0-20160425020131-cfa635847112/go.mod h1:Z4AUp2Km+PwemOoO/VB5AOx9XSsIItzFjoJlOSiYmn0= +github.com/twmb/murmur3 v1.1.6 h1:mqrRot1BRxm+Yct+vavLMou2/iJt0tNVTTC0QoIjaZg= +github.com/twmb/murmur3 v1.1.6/go.mod h1:Qq/R7NUyOfr65zD+6Q5IHKsJLwP7exErjN6lyyq3OSQ= +github.com/ugorji/go v0.0.0-20171122102828-84cb69a8af83/go.mod h1:hnLbHMwcvSihnDhEfx2/BzKp2xb0Y+ErdfYcrs9tkJQ= +github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= +github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= +github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= +github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= +github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= +github.com/yuin/gopher-lua v0.0.0-20171031051903-609c9cd26973/go.mod h1:aEV29XrmTYFr3CiRxZeGHpkvbwq+prZduBqMaascyCU= +go.etcd.io/etcd/api/v3 v3.5.9 h1:4wSsluwyTbGGmyjJktOf3wFQoTBIURXHnq9n/G/JQHs= +go.etcd.io/etcd/api/v3 v3.5.9/go.mod h1:uyAal843mC8uUVSLWz6eHa/d971iDGnCRpmKd2Z+X8k= +go.etcd.io/etcd/client/pkg/v3 v3.5.9 h1:oidDC4+YEuSIQbsR94rY9gur91UPL6DnxDCIYd2IGsE= +go.etcd.io/etcd/client/pkg/v3 v3.5.9/go.mod h1:y+CzeSmkMpWN2Jyu1npecjB9BBnABxGM4pN8cGuJeL4= +go.etcd.io/etcd/client/v3 v3.5.9 h1:r5xghnU7CwbUxD/fbUtRyJGaYNfDun8sp/gTr1hew6E= +go.etcd.io/etcd/client/v3 v3.5.9/go.mod h1:i/Eo5LrZ5IKqpbtpPDuaUnDOUv471oDg8cjQaUr2MbA= +go.opentelemetry.io/otel v1.11.2 h1:YBZcQlsVekzFsFbjygXMOXSs6pialIZxcjfO/mBDmR0= +go.opentelemetry.io/otel v1.11.2/go.mod h1:7p4EUV+AqgdlNV9gL97IgUZiVR3yrFXYo53f9BM3tRI= +go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.11.2 h1:BhEVgvuE1NWLLuMLvC6sif791F45KFHi5GhOs1KunZU= +go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.11.2/go.mod h1:bx//lU66dPzNT+Y0hHA12ciKoMOH9iixEwCqC1OeQWQ= +go.opentelemetry.io/otel/sdk v1.11.2 h1:GF4JoaEx7iihdMFu30sOyRx52HDHOkl9xQ8SMqNXUiU= +go.opentelemetry.io/otel/sdk v1.11.2/go.mod h1:wZ1WxImwpq+lVRo4vsmSOxdd+xwoUJ6rqyLc3SyX9aU= +go.opentelemetry.io/otel/trace v1.11.2 h1:Xf7hWSF2Glv0DE3MH7fBHvtpSBsjcBUe5MYAmZM/+y0= +go.opentelemetry.io/otel/trace v1.11.2/go.mod h1:4N+yC7QEz7TTsG9BSRLNAa63eg5E06ObSbKPmxQ/pKA= +go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= +go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE= +go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= +go.uber.org/goleak v1.1.11-0.20210813005559-691160354723 h1:sHOAIxRGBp443oHZIPB+HsUGaksVCXVQENPxwTfQdH4= +go.uber.org/goleak v1.1.11-0.20210813005559-691160354723/go.mod h1:cwTWslyiVhfpKIDGSZEM2HlOvcqm+tG4zioyIeLoqMQ= +go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9iU= +go.uber.org/multierr v1.7.0 h1:zaiO/rmgFjbmCXdSYJWQcdvOCsthmdaHfr3Gm2Kx4Ec= +go.uber.org/multierr v1.7.0/go.mod h1:7EAYxJLBy9rStEaz58O2t4Uvip6FSURkq8/ppBp95ak= +go.uber.org/zap v1.19.1 h1:ue41HOKd1vGURxrmeKIgELGb3jPW9DMUDGtsinblHwI= +go.uber.org/zap v1.19.1/go.mod h1:j3DNczoxDZroyBnOT1L/Q79cfUMGZxlv/9dzN7SM1rI= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.24.0 h1:mnl8DM0o513X8fdIkmyFE/5hTYxbwYOjDS/+rK6qpRI= +golang.org/x/crypto v0.24.0/go.mod h1:Z1PMYSOR5nyMcyAVAIQSKCDwalqy85Aqn1x3Ws4L5DM= +golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= +golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= +golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= +golang.org/x/net v0.23.0 h1:7EYJ93RZ9vYSZAIb2x3lnuvqO5zneoD6IvWjuhfxjTs= +golang.org/x/net v0.23.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg= +golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= +golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191120155948-bd437916bb0e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.21.0 h1:rF+pYz3DAGSQAxAu1CbC7catZg4ebC4UIeIhKxBZvws= +golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4= +golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= +golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= +golang.org/x/tools v0.1.5/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +google.golang.org/genproto v0.0.0-20240227224415-6ceb2ff114de h1:F6qOa9AZTYJXOUEr4jDysRDLrm4PHePlge4v4TGAlxY= +google.golang.org/genproto v0.0.0-20240227224415-6ceb2ff114de/go.mod h1:VUhTRKeHn9wwcdrk73nvdC9gF178Tzhmt/qyaFcPLSo= +google.golang.org/genproto/googleapis/api v0.0.0-20240227224415-6ceb2ff114de h1:jFNzHPIeuzhdRwVhbZdiym9q0ory/xY3sA+v2wPg8I0= +google.golang.org/genproto/googleapis/api v0.0.0-20240227224415-6ceb2ff114de/go.mod h1:5iCWqnniDlqZHrd3neWVTOwvh/v6s3232omMecelax8= +google.golang.org/genproto/googleapis/rpc v0.0.0-20240227224415-6ceb2ff114de h1:cZGRis4/ot9uVm639a+rHCUaG0JJHEsdyzSQTMX+suY= +google.golang.org/genproto/googleapis/rpc v0.0.0-20240227224415-6ceb2ff114de/go.mod h1:H4O17MA/PE9BsGx3w+a+W2VOLLD1Qf7oJneAoU6WktY= +google.golang.org/grpc v1.63.0 h1:WjKe+dnvABXyPJMD7KDNLxtoGk5tgk+YFWN6cBWjZE8= +google.golang.org/grpc v1.63.0/go.mod h1:WAX/8DgncnokcFUldAxq7GeB5DXHDbMF+lLvDomNkRA= +google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg= +google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/yaml.v2 v2.2.1 h1:mUhvW9EsL+naU5Q3cakzfE91YhliOondGd6ZrsDBHQE= -gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/fsnotify.v1 v1.4.7 h1:xOHLXZwVvI9hhs+cLKq5+I5onOuwQLhQwiu63xxlHs4= +gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys= +gopkg.in/mgo.v2 v2.0.0-20190816093944-a6b53ec6cb22/go.mod h1:yeKp02qBN3iKW1OzL3MGk2IdtZzaj7SFntXj72NppTA= +gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ= +gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= +gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/httplib/httplib_test.go b/httplib/httplib_test.go deleted file mode 100644 index 7314ae01cf..0000000000 --- a/httplib/httplib_test.go +++ /dev/null @@ -1,243 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package httplib - -import ( - "io/ioutil" - "net" - "net/http" - "os" - "strings" - "testing" - "time" -) - -func TestResponse(t *testing.T) { - req := Get("http://httpbin.org/get") - resp, err := req.Response() - if err != nil { - t.Fatal(err) - } - t.Log(resp) -} - -func TestGet(t *testing.T) { - req := Get("http://httpbin.org/get") - b, err := req.Bytes() - if err != nil { - t.Fatal(err) - } - t.Log(b) - - s, err := req.String() - if err != nil { - t.Fatal(err) - } - t.Log(s) - - if string(b) != s { - t.Fatal("request data not match") - } -} - -func TestSimplePost(t *testing.T) { - v := "smallfish" - req := Post("http://httpbin.org/post") - req.Param("username", v) - - str, err := req.String() - if err != nil { - t.Fatal(err) - } - t.Log(str) - - n := strings.Index(str, v) - if n == -1 { - t.Fatal(v + " not found in post") - } -} - -//func TestPostFile(t *testing.T) { -// v := "smallfish" -// req := Post("http://httpbin.org/post") -// req.Debug(true) -// req.Param("username", v) -// req.PostFile("uploadfile", "httplib_test.go") - -// str, err := req.String() -// if err != nil { -// t.Fatal(err) -// } -// t.Log(str) - -// n := strings.Index(str, v) -// if n == -1 { -// t.Fatal(v + " not found in post") -// } -//} - -func TestSimplePut(t *testing.T) { - str, err := Put("http://httpbin.org/put").String() - if err != nil { - t.Fatal(err) - } - t.Log(str) -} - -func TestSimpleDelete(t *testing.T) { - str, err := Delete("http://httpbin.org/delete").String() - if err != nil { - t.Fatal(err) - } - t.Log(str) -} - -func TestSimpleDeleteParam(t *testing.T) { - str, err := Delete("http://httpbin.org/delete").Param("key", "val").String() - if err != nil { - t.Fatal(err) - } - t.Log(str) -} - -func TestWithCookie(t *testing.T) { - v := "smallfish" - str, err := Get("http://httpbin.org/cookies/set?k1=" + v).SetEnableCookie(true).String() - if err != nil { - t.Fatal(err) - } - t.Log(str) - - str, err = Get("http://httpbin.org/cookies").SetEnableCookie(true).String() - if err != nil { - t.Fatal(err) - } - t.Log(str) - - n := strings.Index(str, v) - if n == -1 { - t.Fatal(v + " not found in cookie") - } -} - -func TestWithBasicAuth(t *testing.T) { - str, err := Get("http://httpbin.org/basic-auth/user/passwd").SetBasicAuth("user", "passwd").String() - if err != nil { - t.Fatal(err) - } - t.Log(str) - n := strings.Index(str, "authenticated") - if n == -1 { - t.Fatal("authenticated not found in response") - } -} - -func TestWithUserAgent(t *testing.T) { - v := "beego" - str, err := Get("http://httpbin.org/headers").SetUserAgent(v).String() - if err != nil { - t.Fatal(err) - } - t.Log(str) - - n := strings.Index(str, v) - if n == -1 { - t.Fatal(v + " not found in user-agent") - } -} - -func TestWithSetting(t *testing.T) { - v := "beego" - var setting BeegoHTTPSettings - setting.EnableCookie = true - setting.UserAgent = v - setting.Transport = &http.Transport{ - DialContext: (&net.Dialer{ - Timeout: 30 * time.Second, - KeepAlive: 30 * time.Second, - DualStack: true, - }).DialContext, - MaxIdleConns: 50, - IdleConnTimeout: 90 * time.Second, - ExpectContinueTimeout: 1 * time.Second, - } - setting.ReadWriteTimeout = 5 * time.Second - SetDefaultSetting(setting) - - str, err := Get("http://httpbin.org/get").String() - if err != nil { - t.Fatal(err) - } - t.Log(str) - - n := strings.Index(str, v) - if n == -1 { - t.Fatal(v + " not found in user-agent") - } -} - -func TestToJson(t *testing.T) { - req := Get("http://httpbin.org/ip") - resp, err := req.Response() - if err != nil { - t.Fatal(err) - } - t.Log(resp) - - // httpbin will return http remote addr - type IP struct { - Origin string `json:"origin"` - } - var ip IP - err = req.ToJSON(&ip) - if err != nil { - t.Fatal(err) - } - t.Log(ip.Origin) - ips := strings.Split(ip.Origin, ",") - if len(ips) == 0 { - t.Fatal("response is not valid ip") - } - for i := range ips { - if net.ParseIP(strings.TrimSpace(ips[i])).To4() == nil { - t.Fatal("response is not valid ip") - } - } - -} - -func TestToFile(t *testing.T) { - f := "beego_testfile" - req := Get("http://httpbin.org/ip") - err := req.ToFile(f) - if err != nil { - t.Fatal(err) - } - defer os.Remove(f) - b, err := ioutil.ReadFile(f) - if n := strings.Index(string(b), "origin"); n == -1 { - t.Fatal(err) - } -} - -func TestHeader(t *testing.T) { - req := Get("http://httpbin.org/headers") - req.Header("User-Agent", "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_9_0) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/31.0.1650.57 Safari/537.36") - str, err := req.String() - if err != nil { - t.Fatal(err) - } - t.Log(str) -} diff --git a/log.go b/log.go deleted file mode 100644 index cc4c0f81ab..0000000000 --- a/log.go +++ /dev/null @@ -1,127 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package beego - -import ( - "strings" - - "github.com/astaxie/beego/logs" -) - -// Log levels to control the logging output. -// Deprecated: use github.com/astaxie/beego/logs instead. -const ( - LevelEmergency = iota - LevelAlert - LevelCritical - LevelError - LevelWarning - LevelNotice - LevelInformational - LevelDebug -) - -// BeeLogger references the used application logger. -// Deprecated: use github.com/astaxie/beego/logs instead. -var BeeLogger = logs.GetBeeLogger() - -// SetLevel sets the global log level used by the simple logger. -// Deprecated: use github.com/astaxie/beego/logs instead. -func SetLevel(l int) { - logs.SetLevel(l) -} - -// SetLogFuncCall set the CallDepth, default is 3 -// Deprecated: use github.com/astaxie/beego/logs instead. -func SetLogFuncCall(b bool) { - logs.SetLogFuncCall(b) -} - -// SetLogger sets a new logger. -// Deprecated: use github.com/astaxie/beego/logs instead. -func SetLogger(adaptername string, config string) error { - return logs.SetLogger(adaptername, config) -} - -// Emergency logs a message at emergency level. -// Deprecated: use github.com/astaxie/beego/logs instead. -func Emergency(v ...interface{}) { - logs.Emergency(generateFmtStr(len(v)), v...) -} - -// Alert logs a message at alert level. -// Deprecated: use github.com/astaxie/beego/logs instead. -func Alert(v ...interface{}) { - logs.Alert(generateFmtStr(len(v)), v...) -} - -// Critical logs a message at critical level. -// Deprecated: use github.com/astaxie/beego/logs instead. -func Critical(v ...interface{}) { - logs.Critical(generateFmtStr(len(v)), v...) -} - -// Error logs a message at error level. -// Deprecated: use github.com/astaxie/beego/logs instead. -func Error(v ...interface{}) { - logs.Error(generateFmtStr(len(v)), v...) -} - -// Warning logs a message at warning level. -// Deprecated: use github.com/astaxie/beego/logs instead. -func Warning(v ...interface{}) { - logs.Warning(generateFmtStr(len(v)), v...) -} - -// Warn compatibility alias for Warning() -// Deprecated: use github.com/astaxie/beego/logs instead. -func Warn(v ...interface{}) { - logs.Warn(generateFmtStr(len(v)), v...) -} - -// Notice logs a message at notice level. -// Deprecated: use github.com/astaxie/beego/logs instead. -func Notice(v ...interface{}) { - logs.Notice(generateFmtStr(len(v)), v...) -} - -// Informational logs a message at info level. -// Deprecated: use github.com/astaxie/beego/logs instead. -func Informational(v ...interface{}) { - logs.Informational(generateFmtStr(len(v)), v...) -} - -// Info compatibility alias for Warning() -// Deprecated: use github.com/astaxie/beego/logs instead. -func Info(v ...interface{}) { - logs.Info(generateFmtStr(len(v)), v...) -} - -// Debug logs a message at debug level. -// Deprecated: use github.com/astaxie/beego/logs instead. -func Debug(v ...interface{}) { - logs.Debug(generateFmtStr(len(v)), v...) -} - -// Trace logs a message at trace level. -// compatibility alias for Warning() -// Deprecated: use github.com/astaxie/beego/logs instead. -func Trace(v ...interface{}) { - logs.Trace(generateFmtStr(len(v)), v...) -} - -func generateFmtStr(n int) string { - return strings.Repeat("%v ", n) -} diff --git a/logs/es/es.go b/logs/es/es.go deleted file mode 100644 index 9d6a615c27..0000000000 --- a/logs/es/es.go +++ /dev/null @@ -1,81 +0,0 @@ -package es - -import ( - "encoding/json" - "errors" - "fmt" - "net" - "net/url" - "time" - - "github.com/OwnLocal/goes" - "github.com/astaxie/beego/logs" -) - -// NewES return a LoggerInterface -func NewES() logs.Logger { - cw := &esLogger{ - Level: logs.LevelDebug, - } - return cw -} - -type esLogger struct { - *goes.Client - DSN string `json:"dsn"` - Level int `json:"level"` -} - -// {"dsn":"http://localhost:9200/","level":1} -func (el *esLogger) Init(jsonconfig string) error { - err := json.Unmarshal([]byte(jsonconfig), el) - if err != nil { - return err - } - if el.DSN == "" { - return errors.New("empty dsn") - } else if u, err := url.Parse(el.DSN); err != nil { - return err - } else if u.Path == "" { - return errors.New("missing prefix") - } else if host, port, err := net.SplitHostPort(u.Host); err != nil { - return err - } else { - conn := goes.NewClient(host, port) - el.Client = conn - } - return nil -} - -// WriteMsg will write the msg and level into es -func (el *esLogger) WriteMsg(when time.Time, msg string, level int) error { - if level > el.Level { - return nil - } - - vals := make(map[string]interface{}) - vals["@timestamp"] = when.Format(time.RFC3339) - vals["@msg"] = msg - d := goes.Document{ - Index: fmt.Sprintf("%04d.%02d.%02d", when.Year(), when.Month(), when.Day()), - Type: "logs", - Fields: vals, - } - _, err := el.Index(d, nil) - return err -} - -// Destroy is a empty method -func (el *esLogger) Destroy() { - -} - -// Flush is a empty method -func (el *esLogger) Flush() { - -} - -func init() { - logs.Register(logs.AdapterEs, NewES) -} - diff --git a/logs/slack.go b/logs/slack.go deleted file mode 100644 index 1cd2e5aeeb..0000000000 --- a/logs/slack.go +++ /dev/null @@ -1,60 +0,0 @@ -package logs - -import ( - "encoding/json" - "fmt" - "net/http" - "net/url" - "time" -) - -// SLACKWriter implements beego LoggerInterface and is used to send jiaoliao webhook -type SLACKWriter struct { - WebhookURL string `json:"webhookurl"` - Level int `json:"level"` -} - -// newSLACKWriter create jiaoliao writer. -func newSLACKWriter() Logger { - return &SLACKWriter{Level: LevelTrace} -} - -// Init SLACKWriter with json config string -func (s *SLACKWriter) Init(jsonconfig string) error { - return json.Unmarshal([]byte(jsonconfig), s) -} - -// WriteMsg write message in smtp writer. -// it will send an email with subject and only this message. -func (s *SLACKWriter) WriteMsg(when time.Time, msg string, level int) error { - if level > s.Level { - return nil - } - - text := fmt.Sprintf("{\"text\": \"%s %s\"}", when.Format("2006-01-02 15:04:05"), msg) - - form := url.Values{} - form.Add("payload", text) - - resp, err := http.PostForm(s.WebhookURL, form) - if err != nil { - return err - } - defer resp.Body.Close() - if resp.StatusCode != http.StatusOK { - return fmt.Errorf("Post webhook failed %s %d", resp.Status, resp.StatusCode) - } - return nil -} - -// Flush implementing method. empty. -func (s *SLACKWriter) Flush() { -} - -// Destroy implementing method. empty. -func (s *SLACKWriter) Destroy() { -} - -func init() { - Register(AdapterSlack, newSLACKWriter) -} diff --git a/namespace_test.go b/namespace_test.go deleted file mode 100644 index b3f20dff22..0000000000 --- a/namespace_test.go +++ /dev/null @@ -1,168 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package beego - -import ( - "net/http" - "net/http/httptest" - "strconv" - "testing" - - "github.com/astaxie/beego/context" -) - -func TestNamespaceGet(t *testing.T) { - r, _ := http.NewRequest("GET", "/v1/user", nil) - w := httptest.NewRecorder() - - ns := NewNamespace("/v1") - ns.Get("/user", func(ctx *context.Context) { - ctx.Output.Body([]byte("v1_user")) - }) - AddNamespace(ns) - BeeApp.Handlers.ServeHTTP(w, r) - if w.Body.String() != "v1_user" { - t.Errorf("TestNamespaceGet can't run, get the response is " + w.Body.String()) - } -} - -func TestNamespacePost(t *testing.T) { - r, _ := http.NewRequest("POST", "/v1/user/123", nil) - w := httptest.NewRecorder() - - ns := NewNamespace("/v1") - ns.Post("/user/:id", func(ctx *context.Context) { - ctx.Output.Body([]byte(ctx.Input.Param(":id"))) - }) - AddNamespace(ns) - BeeApp.Handlers.ServeHTTP(w, r) - if w.Body.String() != "123" { - t.Errorf("TestNamespacePost can't run, get the response is " + w.Body.String()) - } -} - -func TestNamespaceNest(t *testing.T) { - r, _ := http.NewRequest("GET", "/v1/admin/order", nil) - w := httptest.NewRecorder() - - ns := NewNamespace("/v1") - ns.Namespace( - NewNamespace("/admin"). - Get("/order", func(ctx *context.Context) { - ctx.Output.Body([]byte("order")) - }), - ) - AddNamespace(ns) - BeeApp.Handlers.ServeHTTP(w, r) - if w.Body.String() != "order" { - t.Errorf("TestNamespaceNest can't run, get the response is " + w.Body.String()) - } -} - -func TestNamespaceNestParam(t *testing.T) { - r, _ := http.NewRequest("GET", "/v1/admin/order/123", nil) - w := httptest.NewRecorder() - - ns := NewNamespace("/v1") - ns.Namespace( - NewNamespace("/admin"). - Get("/order/:id", func(ctx *context.Context) { - ctx.Output.Body([]byte(ctx.Input.Param(":id"))) - }), - ) - AddNamespace(ns) - BeeApp.Handlers.ServeHTTP(w, r) - if w.Body.String() != "123" { - t.Errorf("TestNamespaceNestParam can't run, get the response is " + w.Body.String()) - } -} - -func TestNamespaceRouter(t *testing.T) { - r, _ := http.NewRequest("GET", "/v1/api/list", nil) - w := httptest.NewRecorder() - - ns := NewNamespace("/v1") - ns.Router("/api/list", &TestController{}, "*:List") - AddNamespace(ns) - BeeApp.Handlers.ServeHTTP(w, r) - if w.Body.String() != "i am list" { - t.Errorf("TestNamespaceRouter can't run, get the response is " + w.Body.String()) - } -} - -func TestNamespaceAutoFunc(t *testing.T) { - r, _ := http.NewRequest("GET", "/v1/test/list", nil) - w := httptest.NewRecorder() - - ns := NewNamespace("/v1") - ns.AutoRouter(&TestController{}) - AddNamespace(ns) - BeeApp.Handlers.ServeHTTP(w, r) - if w.Body.String() != "i am list" { - t.Errorf("user define func can't run") - } -} - -func TestNamespaceFilter(t *testing.T) { - r, _ := http.NewRequest("GET", "/v1/user/123", nil) - w := httptest.NewRecorder() - - ns := NewNamespace("/v1") - ns.Filter("before", func(ctx *context.Context) { - ctx.Output.Body([]byte("this is Filter")) - }). - Get("/user/:id", func(ctx *context.Context) { - ctx.Output.Body([]byte(ctx.Input.Param(":id"))) - }) - AddNamespace(ns) - BeeApp.Handlers.ServeHTTP(w, r) - if w.Body.String() != "this is Filter" { - t.Errorf("TestNamespaceFilter can't run, get the response is " + w.Body.String()) - } -} - -func TestNamespaceCond(t *testing.T) { - r, _ := http.NewRequest("GET", "/v2/test/list", nil) - w := httptest.NewRecorder() - - ns := NewNamespace("/v2") - ns.Cond(func(ctx *context.Context) bool { - return ctx.Input.Domain() == "beego.me" - }). - AutoRouter(&TestController{}) - AddNamespace(ns) - BeeApp.Handlers.ServeHTTP(w, r) - if w.Code != 405 { - t.Errorf("TestNamespaceCond can't run get the result " + strconv.Itoa(w.Code)) - } -} - -func TestNamespaceInside(t *testing.T) { - r, _ := http.NewRequest("GET", "/v3/shop/order/123", nil) - w := httptest.NewRecorder() - ns := NewNamespace("/v3", - NSAutoRouter(&TestController{}), - NSNamespace("/shop", - NSGet("/order/:id", func(ctx *context.Context) { - ctx.Output.Body([]byte(ctx.Input.Param(":id"))) - }), - ), - ) - AddNamespace(ns) - BeeApp.Handlers.ServeHTTP(w, r) - if w.Body.String() != "123" { - t.Errorf("TestNamespaceInside can't run, get the response is " + w.Body.String()) - } -} diff --git a/orm/cmd_utils.go b/orm/cmd_utils.go deleted file mode 100644 index 7c9aa51ed1..0000000000 --- a/orm/cmd_utils.go +++ /dev/null @@ -1,320 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package orm - -import ( - "fmt" - "os" - "strings" -) - -type dbIndex struct { - Table string - Name string - SQL string -} - -// create database drop sql. -func getDbDropSQL(al *alias) (sqls []string) { - if len(modelCache.cache) == 0 { - fmt.Println("no Model found, need register your model") - os.Exit(2) - } - - Q := al.DbBaser.TableQuote() - - for _, mi := range modelCache.allOrdered() { - sqls = append(sqls, fmt.Sprintf(`DROP TABLE IF EXISTS %s%s%s`, Q, mi.table, Q)) - } - return sqls -} - -// get database column type string. -func getColumnTyp(al *alias, fi *fieldInfo) (col string) { - T := al.DbBaser.DbTypes() - fieldType := fi.fieldType - fieldSize := fi.size - -checkColumn: - switch fieldType { - case TypeBooleanField: - col = T["bool"] - case TypeVarCharField: - if al.Driver == DRPostgres && fi.toText { - col = T["string-text"] - } else { - col = fmt.Sprintf(T["string"], fieldSize) - } - case TypeCharField: - col = fmt.Sprintf(T["string-char"], fieldSize) - case TypeTextField: - col = T["string-text"] - case TypeTimeField: - col = T["time.Time-clock"] - case TypeDateField: - col = T["time.Time-date"] - case TypeDateTimeField: - col = T["time.Time"] - case TypeBitField: - col = T["int8"] - case TypeSmallIntegerField: - col = T["int16"] - case TypeIntegerField: - col = T["int32"] - case TypeBigIntegerField: - if al.Driver == DRSqlite { - fieldType = TypeIntegerField - goto checkColumn - } - col = T["int64"] - case TypePositiveBitField: - col = T["uint8"] - case TypePositiveSmallIntegerField: - col = T["uint16"] - case TypePositiveIntegerField: - col = T["uint32"] - case TypePositiveBigIntegerField: - col = T["uint64"] - case TypeFloatField: - col = T["float64"] - case TypeDecimalField: - s := T["float64-decimal"] - if !strings.Contains(s, "%d") { - col = s - } else { - col = fmt.Sprintf(s, fi.digits, fi.decimals) - } - case TypeJSONField: - if al.Driver != DRPostgres { - fieldType = TypeVarCharField - goto checkColumn - } - col = T["json"] - case TypeJsonbField: - if al.Driver != DRPostgres { - fieldType = TypeVarCharField - goto checkColumn - } - col = T["jsonb"] - case RelForeignKey, RelOneToOne: - fieldType = fi.relModelInfo.fields.pk.fieldType - fieldSize = fi.relModelInfo.fields.pk.size - goto checkColumn - } - - return -} - -// create alter sql string. -func getColumnAddQuery(al *alias, fi *fieldInfo) string { - Q := al.DbBaser.TableQuote() - typ := getColumnTyp(al, fi) - - if !fi.null { - typ += " " + "NOT NULL" - } - - return fmt.Sprintf("ALTER TABLE %s%s%s ADD COLUMN %s%s%s %s %s", - Q, fi.mi.table, Q, - Q, fi.column, Q, - typ, getColumnDefault(fi), - ) -} - -// create database creation string. -func getDbCreateSQL(al *alias) (sqls []string, tableIndexes map[string][]dbIndex) { - if len(modelCache.cache) == 0 { - fmt.Println("no Model found, need register your model") - os.Exit(2) - } - - Q := al.DbBaser.TableQuote() - T := al.DbBaser.DbTypes() - sep := fmt.Sprintf("%s, %s", Q, Q) - - tableIndexes = make(map[string][]dbIndex) - - for _, mi := range modelCache.allOrdered() { - sql := fmt.Sprintf("-- %s\n", strings.Repeat("-", 50)) - sql += fmt.Sprintf("-- Table Structure for `%s`\n", mi.fullName) - sql += fmt.Sprintf("-- %s\n", strings.Repeat("-", 50)) - - sql += fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s%s%s (\n", Q, mi.table, Q) - - columns := make([]string, 0, len(mi.fields.fieldsDB)) - - sqlIndexes := [][]string{} - - for _, fi := range mi.fields.fieldsDB { - - column := fmt.Sprintf(" %s%s%s ", Q, fi.column, Q) - col := getColumnTyp(al, fi) - - if fi.auto { - switch al.Driver { - case DRSqlite, DRPostgres: - column += T["auto"] - default: - column += col + " " + T["auto"] - } - } else if fi.pk { - column += col + " " + T["pk"] - } else { - column += col - - if !fi.null { - column += " " + "NOT NULL" - } - - //if fi.initial.String() != "" { - // column += " DEFAULT " + fi.initial.String() - //} - - // Append attribute DEFAULT - column += getColumnDefault(fi) - - if fi.unique { - column += " " + "UNIQUE" - } - - if fi.index { - sqlIndexes = append(sqlIndexes, []string{fi.column}) - } - } - - if strings.Contains(column, "%COL%") { - column = strings.Replace(column, "%COL%", fi.column, -1) - } - - if fi.description != "" { - column += " " + fmt.Sprintf("COMMENT '%s'",fi.description) - } - - columns = append(columns, column) - } - - if mi.model != nil { - allnames := getTableUnique(mi.addrField) - if !mi.manual && len(mi.uniques) > 0 { - allnames = append(allnames, mi.uniques) - } - for _, names := range allnames { - cols := make([]string, 0, len(names)) - for _, name := range names { - if fi, ok := mi.fields.GetByAny(name); ok && fi.dbcol { - cols = append(cols, fi.column) - } else { - panic(fmt.Errorf("cannot found column `%s` when parse UNIQUE in `%s.TableUnique`", name, mi.fullName)) - } - } - column := fmt.Sprintf(" UNIQUE (%s%s%s)", Q, strings.Join(cols, sep), Q) - columns = append(columns, column) - } - } - - sql += strings.Join(columns, ",\n") - sql += "\n)" - - if al.Driver == DRMySQL { - var engine string - if mi.model != nil { - engine = getTableEngine(mi.addrField) - } - if engine == "" { - engine = al.Engine - } - sql += " ENGINE=" + engine - } - - sql += ";" - sqls = append(sqls, sql) - - if mi.model != nil { - for _, names := range getTableIndex(mi.addrField) { - cols := make([]string, 0, len(names)) - for _, name := range names { - if fi, ok := mi.fields.GetByAny(name); ok && fi.dbcol { - cols = append(cols, fi.column) - } else { - panic(fmt.Errorf("cannot found column `%s` when parse INDEX in `%s.TableIndex`", name, mi.fullName)) - } - } - sqlIndexes = append(sqlIndexes, cols) - } - } - - for _, names := range sqlIndexes { - name := mi.table + "_" + strings.Join(names, "_") - cols := strings.Join(names, sep) - sql := fmt.Sprintf("CREATE INDEX %s%s%s ON %s%s%s (%s%s%s);", Q, name, Q, Q, mi.table, Q, Q, cols, Q) - - index := dbIndex{} - index.Table = mi.table - index.Name = name - index.SQL = sql - - tableIndexes[mi.table] = append(tableIndexes[mi.table], index) - } - - } - - return -} - -// Get string value for the attribute "DEFAULT" for the CREATE, ALTER commands -func getColumnDefault(fi *fieldInfo) string { - var ( - v, t, d string - ) - - // Skip default attribute if field is in relations - if fi.rel || fi.reverse { - return v - } - - t = " DEFAULT '%s' " - - // These defaults will be useful if there no config value orm:"default" and NOT NULL is on - switch fi.fieldType { - case TypeTimeField, TypeDateField, TypeDateTimeField, TypeTextField: - return v - - case TypeBitField, TypeSmallIntegerField, TypeIntegerField, - TypeBigIntegerField, TypePositiveBitField, TypePositiveSmallIntegerField, - TypePositiveIntegerField, TypePositiveBigIntegerField, TypeFloatField, - TypeDecimalField: - t = " DEFAULT %s " - d = "0" - case TypeBooleanField: - t = " DEFAULT %s " - d = "FALSE" - case TypeJSONField, TypeJsonbField: - d = "{}" - } - - if fi.colDefault { - if !fi.initial.Exist() { - v = fmt.Sprintf(t, "") - } else { - v = fmt.Sprintf(t, fi.initial.String()) - } - } else { - if !fi.null { - v = fmt.Sprintf(t, d) - } - } - - return v -} diff --git a/orm/db_alias.go b/orm/db_alias.go deleted file mode 100644 index a43e70e3b0..0000000000 --- a/orm/db_alias.go +++ /dev/null @@ -1,302 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package orm - -import ( - "database/sql" - "fmt" - "reflect" - "sync" - "time" -) - -// DriverType database driver constant int. -type DriverType int - -// Enum the Database driver -const ( - _ DriverType = iota // int enum type - DRMySQL // mysql - DRSqlite // sqlite - DROracle // oracle - DRPostgres // pgsql - DRTiDB // TiDB -) - -// database driver string. -type driver string - -// get type constant int of current driver.. -func (d driver) Type() DriverType { - a, _ := dataBaseCache.get(string(d)) - return a.Driver -} - -// get name of current driver -func (d driver) Name() string { - return string(d) -} - -// check driver iis implemented Driver interface or not. -var _ Driver = new(driver) - -var ( - dataBaseCache = &_dbCache{cache: make(map[string]*alias)} - drivers = map[string]DriverType{ - "mysql": DRMySQL, - "postgres": DRPostgres, - "sqlite3": DRSqlite, - "tidb": DRTiDB, - "oracle": DROracle, - "oci8": DROracle, // github.com/mattn/go-oci8 - "ora": DROracle, //https://github.com/rana/ora - } - dbBasers = map[DriverType]dbBaser{ - DRMySQL: newdbBaseMysql(), - DRSqlite: newdbBaseSqlite(), - DROracle: newdbBaseOracle(), - DRPostgres: newdbBasePostgres(), - DRTiDB: newdbBaseTidb(), - } -) - -// database alias cacher. -type _dbCache struct { - mux sync.RWMutex - cache map[string]*alias -} - -// add database alias with original name. -func (ac *_dbCache) add(name string, al *alias) (added bool) { - ac.mux.Lock() - defer ac.mux.Unlock() - if _, ok := ac.cache[name]; !ok { - ac.cache[name] = al - added = true - } - return -} - -// get database alias if cached. -func (ac *_dbCache) get(name string) (al *alias, ok bool) { - ac.mux.RLock() - defer ac.mux.RUnlock() - al, ok = ac.cache[name] - return -} - -// get default alias. -func (ac *_dbCache) getDefault() (al *alias) { - al, _ = ac.get("default") - return -} - -type alias struct { - Name string - Driver DriverType - DriverName string - DataSource string - MaxIdleConns int - MaxOpenConns int - DB *sql.DB - DbBaser dbBaser - TZ *time.Location - Engine string -} - -func detectTZ(al *alias) { - // orm timezone system match database - // default use Local - al.TZ = DefaultTimeLoc - - if al.DriverName == "sphinx" { - return - } - - switch al.Driver { - case DRMySQL: - row := al.DB.QueryRow("SELECT TIMEDIFF(NOW(), UTC_TIMESTAMP)") - var tz string - row.Scan(&tz) - if len(tz) >= 8 { - if tz[0] != '-' { - tz = "+" + tz - } - t, err := time.Parse("-07:00:00", tz) - if err == nil { - if t.Location().String() != "" { - al.TZ = t.Location() - } - } else { - DebugLog.Printf("Detect DB timezone: %s %s\n", tz, err.Error()) - } - } - - // get default engine from current database - row = al.DB.QueryRow("SELECT ENGINE, TRANSACTIONS FROM information_schema.engines WHERE SUPPORT = 'DEFAULT'") - var engine string - var tx bool - row.Scan(&engine, &tx) - - if engine != "" { - al.Engine = engine - } else { - al.Engine = "INNODB" - } - - case DRSqlite, DROracle: - al.TZ = time.UTC - - case DRPostgres: - row := al.DB.QueryRow("SELECT current_setting('TIMEZONE')") - var tz string - row.Scan(&tz) - loc, err := time.LoadLocation(tz) - if err == nil { - al.TZ = loc - } else { - DebugLog.Printf("Detect DB timezone: %s %s\n", tz, err.Error()) - } - } -} - -func addAliasWthDB(aliasName, driverName string, db *sql.DB) (*alias, error) { - al := new(alias) - al.Name = aliasName - al.DriverName = driverName - al.DB = db - - if dr, ok := drivers[driverName]; ok { - al.DbBaser = dbBasers[dr] - al.Driver = dr - } else { - return nil, fmt.Errorf("driver name `%s` have not registered", driverName) - } - - err := db.Ping() - if err != nil { - return nil, fmt.Errorf("register db Ping `%s`, %s", aliasName, err.Error()) - } - - if !dataBaseCache.add(aliasName, al) { - return nil, fmt.Errorf("DataBase alias name `%s` already registered, cannot reuse", aliasName) - } - - return al, nil -} - -// AddAliasWthDB add a aliasName for the drivename -func AddAliasWthDB(aliasName, driverName string, db *sql.DB) error { - _, err := addAliasWthDB(aliasName, driverName, db) - return err -} - -// RegisterDataBase Setting the database connect params. Use the database driver self dataSource args. -func RegisterDataBase(aliasName, driverName, dataSource string, params ...int) error { - var ( - err error - db *sql.DB - al *alias - ) - - db, err = sql.Open(driverName, dataSource) - if err != nil { - err = fmt.Errorf("register db `%s`, %s", aliasName, err.Error()) - goto end - } - - al, err = addAliasWthDB(aliasName, driverName, db) - if err != nil { - goto end - } - - al.DataSource = dataSource - - detectTZ(al) - - for i, v := range params { - switch i { - case 0: - SetMaxIdleConns(al.Name, v) - case 1: - SetMaxOpenConns(al.Name, v) - } - } - -end: - if err != nil { - if db != nil { - db.Close() - } - DebugLog.Println(err.Error()) - } - - return err -} - -// RegisterDriver Register a database driver use specify driver name, this can be definition the driver is which database type. -func RegisterDriver(driverName string, typ DriverType) error { - if t, ok := drivers[driverName]; !ok { - drivers[driverName] = typ - } else { - if t != typ { - return fmt.Errorf("driverName `%s` db driver already registered and is other type", driverName) - } - } - return nil -} - -// SetDataBaseTZ Change the database default used timezone -func SetDataBaseTZ(aliasName string, tz *time.Location) error { - if al, ok := dataBaseCache.get(aliasName); ok { - al.TZ = tz - } else { - return fmt.Errorf("DataBase alias name `%s` not registered", aliasName) - } - return nil -} - -// SetMaxIdleConns Change the max idle conns for *sql.DB, use specify database alias name -func SetMaxIdleConns(aliasName string, maxIdleConns int) { - al := getDbAlias(aliasName) - al.MaxIdleConns = maxIdleConns - al.DB.SetMaxIdleConns(maxIdleConns) -} - -// SetMaxOpenConns Change the max open conns for *sql.DB, use specify database alias name -func SetMaxOpenConns(aliasName string, maxOpenConns int) { - al := getDbAlias(aliasName) - al.MaxOpenConns = maxOpenConns - // for tip go 1.2 - if fun := reflect.ValueOf(al.DB).MethodByName("SetMaxOpenConns"); fun.IsValid() { - fun.Call([]reflect.Value{reflect.ValueOf(maxOpenConns)}) - } -} - -// GetDB Get *sql.DB from registered database by db alias name. -// Use "default" as alias name if you not set. -func GetDB(aliasNames ...string) (*sql.DB, error) { - var name string - if len(aliasNames) > 0 { - name = aliasNames[0] - } else { - name = "default" - } - al, ok := dataBaseCache.get(name) - if ok { - return al.DB, nil - } - return nil, fmt.Errorf("DataBase of alias name `%s` not found", name) -} diff --git a/orm/models.go b/orm/models.go deleted file mode 100644 index 4776bcba6c..0000000000 --- a/orm/models.go +++ /dev/null @@ -1,99 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package orm - -import ( - "sync" -) - -const ( - odCascade = "cascade" - odSetNULL = "set_null" - odSetDefault = "set_default" - odDoNothing = "do_nothing" - defaultStructTagName = "orm" - defaultStructTagDelim = ";" -) - -var ( - modelCache = &_modelCache{ - cache: make(map[string]*modelInfo), - cacheByFullName: make(map[string]*modelInfo), - } -) - -// model info collection -type _modelCache struct { - sync.RWMutex // only used outsite for bootStrap - orders []string - cache map[string]*modelInfo - cacheByFullName map[string]*modelInfo - done bool -} - -// get all model info -func (mc *_modelCache) all() map[string]*modelInfo { - m := make(map[string]*modelInfo, len(mc.cache)) - for k, v := range mc.cache { - m[k] = v - } - return m -} - -// get ordered model info -func (mc *_modelCache) allOrdered() []*modelInfo { - m := make([]*modelInfo, 0, len(mc.orders)) - for _, table := range mc.orders { - m = append(m, mc.cache[table]) - } - return m -} - -// get model info by table name -func (mc *_modelCache) get(table string) (mi *modelInfo, ok bool) { - mi, ok = mc.cache[table] - return -} - -// get model info by full name -func (mc *_modelCache) getByFullName(name string) (mi *modelInfo, ok bool) { - mi, ok = mc.cacheByFullName[name] - return -} - -// set model info to collection -func (mc *_modelCache) set(table string, mi *modelInfo) *modelInfo { - mii := mc.cache[table] - mc.cache[table] = mi - mc.cacheByFullName[mi.fullName] = mi - if mii == nil { - mc.orders = append(mc.orders, table) - } - return mii -} - -// clean all model info. -func (mc *_modelCache) clean() { - mc.orders = make([]string, 0) - mc.cache = make(map[string]*modelInfo) - mc.cacheByFullName = make(map[string]*modelInfo) - mc.done = false -} - -// ResetModelCache Clean model cache. Then you can re-RegisterModel. -// Common use this api for test case. -func ResetModelCache() { - modelCache.clean() -} diff --git a/orm/models_boot.go b/orm/models_boot.go deleted file mode 100644 index badfd11b49..0000000000 --- a/orm/models_boot.go +++ /dev/null @@ -1,345 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package orm - -import ( - "fmt" - "os" - "reflect" - "strings" -) - -// register models. -// PrefixOrSuffix means table name prefix or suffix. -// isPrefix whether the prefix is prefix or suffix -func registerModel(PrefixOrSuffix string, model interface{}, isPrefix bool) { - val := reflect.ValueOf(model) - typ := reflect.Indirect(val).Type() - - if val.Kind() != reflect.Ptr { - panic(fmt.Errorf(" cannot use non-ptr model struct `%s`", getFullName(typ))) - } - // For this case: - // u := &User{} - // registerModel(&u) - if typ.Kind() == reflect.Ptr { - panic(fmt.Errorf(" only allow ptr model struct, it looks you use two reference to the struct `%s`", typ)) - } - - table := getTableName(val) - - if PrefixOrSuffix != "" { - if isPrefix { - table = PrefixOrSuffix + table - } else { - table = table + PrefixOrSuffix - } - } - // models's fullname is pkgpath + struct name - name := getFullName(typ) - if _, ok := modelCache.getByFullName(name); ok { - fmt.Printf(" model `%s` repeat register, must be unique\n", name) - os.Exit(2) - } - - if _, ok := modelCache.get(table); ok { - fmt.Printf(" table name `%s` repeat register, must be unique\n", table) - os.Exit(2) - } - - mi := newModelInfo(val) - if mi.fields.pk == nil { - outFor: - for _, fi := range mi.fields.fieldsDB { - if strings.ToLower(fi.name) == "id" { - switch fi.addrValue.Elem().Kind() { - case reflect.Int, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint32, reflect.Uint64: - fi.auto = true - fi.pk = true - mi.fields.pk = fi - break outFor - } - } - } - - if mi.fields.pk == nil { - fmt.Printf(" `%s` needs a primary key field, default is to use 'id' if not set\n", name) - os.Exit(2) - } - - } - - mi.table = table - mi.pkg = typ.PkgPath() - mi.model = model - mi.manual = true - - modelCache.set(table, mi) -} - -// bootstrap models -func bootStrap() { - if modelCache.done { - return - } - var ( - err error - models map[string]*modelInfo - ) - if dataBaseCache.getDefault() == nil { - err = fmt.Errorf("must have one register DataBase alias named `default`") - goto end - } - - // set rel and reverse model - // RelManyToMany set the relTable - models = modelCache.all() - for _, mi := range models { - for _, fi := range mi.fields.columns { - if fi.rel || fi.reverse { - elm := fi.addrValue.Type().Elem() - if fi.fieldType == RelReverseMany || fi.fieldType == RelManyToMany { - elm = elm.Elem() - } - // check the rel or reverse model already register - name := getFullName(elm) - mii, ok := modelCache.getByFullName(name) - if !ok || mii.pkg != elm.PkgPath() { - err = fmt.Errorf("can not find rel in field `%s`, `%s` may be miss register", fi.fullName, elm.String()) - goto end - } - fi.relModelInfo = mii - - switch fi.fieldType { - case RelManyToMany: - if fi.relThrough != "" { - if i := strings.LastIndex(fi.relThrough, "."); i != -1 && len(fi.relThrough) > (i+1) { - pn := fi.relThrough[:i] - rmi, ok := modelCache.getByFullName(fi.relThrough) - if !ok || pn != rmi.pkg { - err = fmt.Errorf("field `%s` wrong rel_through value `%s` cannot find table", fi.fullName, fi.relThrough) - goto end - } - fi.relThroughModelInfo = rmi - fi.relTable = rmi.table - } else { - err = fmt.Errorf("field `%s` wrong rel_through value `%s`", fi.fullName, fi.relThrough) - goto end - } - } else { - i := newM2MModelInfo(mi, mii) - if fi.relTable != "" { - i.table = fi.relTable - } - if v := modelCache.set(i.table, i); v != nil { - err = fmt.Errorf("the rel table name `%s` already registered, cannot be use, please change one", fi.relTable) - goto end - } - fi.relTable = i.table - fi.relThroughModelInfo = i - } - - fi.relThroughModelInfo.isThrough = true - } - } - } - } - - // check the rel filed while the relModelInfo also has filed point to current model - // if not exist, add a new field to the relModelInfo - models = modelCache.all() - for _, mi := range models { - for _, fi := range mi.fields.fieldsRel { - switch fi.fieldType { - case RelForeignKey, RelOneToOne, RelManyToMany: - inModel := false - for _, ffi := range fi.relModelInfo.fields.fieldsReverse { - if ffi.relModelInfo == mi { - inModel = true - break - } - } - if !inModel { - rmi := fi.relModelInfo - ffi := new(fieldInfo) - ffi.name = mi.name - ffi.column = ffi.name - ffi.fullName = rmi.fullName + "." + ffi.name - ffi.reverse = true - ffi.relModelInfo = mi - ffi.mi = rmi - if fi.fieldType == RelOneToOne { - ffi.fieldType = RelReverseOne - } else { - ffi.fieldType = RelReverseMany - } - if !rmi.fields.Add(ffi) { - added := false - for cnt := 0; cnt < 5; cnt++ { - ffi.name = fmt.Sprintf("%s%d", mi.name, cnt) - ffi.column = ffi.name - ffi.fullName = rmi.fullName + "." + ffi.name - if added = rmi.fields.Add(ffi); added { - break - } - } - if !added { - panic(fmt.Errorf("cannot generate auto reverse field info `%s` to `%s`", fi.fullName, ffi.fullName)) - } - } - } - } - } - } - - models = modelCache.all() - for _, mi := range models { - for _, fi := range mi.fields.fieldsRel { - switch fi.fieldType { - case RelManyToMany: - for _, ffi := range fi.relThroughModelInfo.fields.fieldsRel { - switch ffi.fieldType { - case RelOneToOne, RelForeignKey: - if ffi.relModelInfo == fi.relModelInfo { - fi.reverseFieldInfoTwo = ffi - } - if ffi.relModelInfo == mi { - fi.reverseField = ffi.name - fi.reverseFieldInfo = ffi - } - } - } - if fi.reverseFieldInfoTwo == nil { - err = fmt.Errorf("can not find m2m field for m2m model `%s`, ensure your m2m model defined correct", - fi.relThroughModelInfo.fullName) - goto end - } - } - } - } - - models = modelCache.all() - for _, mi := range models { - for _, fi := range mi.fields.fieldsReverse { - switch fi.fieldType { - case RelReverseOne: - found := false - mForA: - for _, ffi := range fi.relModelInfo.fields.fieldsByType[RelOneToOne] { - if ffi.relModelInfo == mi { - found = true - fi.reverseField = ffi.name - fi.reverseFieldInfo = ffi - - ffi.reverseField = fi.name - ffi.reverseFieldInfo = fi - break mForA - } - } - if !found { - err = fmt.Errorf("reverse field `%s` not found in model `%s`", fi.fullName, fi.relModelInfo.fullName) - goto end - } - case RelReverseMany: - found := false - mForB: - for _, ffi := range fi.relModelInfo.fields.fieldsByType[RelForeignKey] { - if ffi.relModelInfo == mi { - found = true - fi.reverseField = ffi.name - fi.reverseFieldInfo = ffi - - ffi.reverseField = fi.name - ffi.reverseFieldInfo = fi - - break mForB - } - } - if !found { - mForC: - for _, ffi := range fi.relModelInfo.fields.fieldsByType[RelManyToMany] { - conditions := fi.relThrough != "" && fi.relThrough == ffi.relThrough || - fi.relTable != "" && fi.relTable == ffi.relTable || - fi.relThrough == "" && fi.relTable == "" - if ffi.relModelInfo == mi && conditions { - found = true - - fi.reverseField = ffi.reverseFieldInfoTwo.name - fi.reverseFieldInfo = ffi.reverseFieldInfoTwo - fi.relThroughModelInfo = ffi.relThroughModelInfo - fi.reverseFieldInfoTwo = ffi.reverseFieldInfo - fi.reverseFieldInfoM2M = ffi - ffi.reverseFieldInfoM2M = fi - - break mForC - } - } - } - if !found { - err = fmt.Errorf("reverse field for `%s` not found in model `%s`", fi.fullName, fi.relModelInfo.fullName) - goto end - } - } - } - } - -end: - if err != nil { - fmt.Println(err) - os.Exit(2) - } -} - -// RegisterModel register models -func RegisterModel(models ...interface{}) { - if modelCache.done { - panic(fmt.Errorf("RegisterModel must be run before BootStrap")) - } - RegisterModelWithPrefix("", models...) -} - -// RegisterModelWithPrefix register models with a prefix -func RegisterModelWithPrefix(prefix string, models ...interface{}) { - if modelCache.done { - panic(fmt.Errorf("RegisterModelWithPrefix must be run before BootStrap")) - } - - for _, model := range models { - registerModel(prefix, model, true) - } -} - -// RegisterModelWithSuffix register models with a suffix -func RegisterModelWithSuffix(suffix string, models ...interface{}) { - if modelCache.done { - panic(fmt.Errorf("RegisterModelWithSuffix must be run before BootStrap")) - } - - for _, model := range models { - registerModel(suffix, model, false) - } -} - -// BootStrap bootstrap models. -// make all model parsed and can not add more models -func BootStrap() { - if modelCache.done { - return - } - modelCache.Lock() - defer modelCache.Unlock() - bootStrap() - modelCache.done = true -} diff --git a/orm/models_info_m.go b/orm/models_info_m.go deleted file mode 100644 index a4d733b6ce..0000000000 --- a/orm/models_info_m.go +++ /dev/null @@ -1,148 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package orm - -import ( - "fmt" - "os" - "reflect" -) - -// single model info -type modelInfo struct { - pkg string - name string - fullName string - table string - model interface{} - fields *fields - manual bool - addrField reflect.Value //store the original struct value - uniques []string - isThrough bool -} - -// new model info -func newModelInfo(val reflect.Value) (mi *modelInfo) { - mi = &modelInfo{} - mi.fields = newFields() - ind := reflect.Indirect(val) - mi.addrField = val - mi.name = ind.Type().Name() - mi.fullName = getFullName(ind.Type()) - addModelFields(mi, ind, "", []int{}) - return -} - -// index: FieldByIndex returns the nested field corresponding to index -func addModelFields(mi *modelInfo, ind reflect.Value, mName string, index []int) { - var ( - err error - fi *fieldInfo - sf reflect.StructField - ) - - for i := 0; i < ind.NumField(); i++ { - field := ind.Field(i) - sf = ind.Type().Field(i) - // if the field is unexported skip - if sf.PkgPath != "" { - continue - } - // add anonymous struct fields - if sf.Anonymous { - addModelFields(mi, field, mName+"."+sf.Name, append(index, i)) - continue - } - - fi, err = newFieldInfo(mi, field, sf, mName) - if err == errSkipField { - err = nil - continue - } else if err != nil { - break - } - //record current field index - fi.fieldIndex = append(fi.fieldIndex, index...) - fi.fieldIndex = append(fi.fieldIndex, i) - fi.mi = mi - fi.inModel = true - if !mi.fields.Add(fi) { - err = fmt.Errorf("duplicate column name: %s", fi.column) - break - } - if fi.pk { - if mi.fields.pk != nil { - err = fmt.Errorf("one model must have one pk field only") - break - } else { - mi.fields.pk = fi - } - } - } - - if err != nil { - fmt.Println(fmt.Errorf("field: %s.%s, %s", ind.Type(), sf.Name, err)) - os.Exit(2) - } -} - -// combine related model info to new model info. -// prepare for relation models query. -func newM2MModelInfo(m1, m2 *modelInfo) (mi *modelInfo) { - mi = new(modelInfo) - mi.fields = newFields() - mi.table = m1.table + "_" + m2.table + "s" - mi.name = camelString(mi.table) - mi.fullName = m1.pkg + "." + mi.name - - fa := new(fieldInfo) // pk - f1 := new(fieldInfo) // m1 table RelForeignKey - f2 := new(fieldInfo) // m2 table RelForeignKey - fa.fieldType = TypeBigIntegerField - fa.auto = true - fa.pk = true - fa.dbcol = true - fa.name = "Id" - fa.column = "id" - fa.fullName = mi.fullName + "." + fa.name - - f1.dbcol = true - f2.dbcol = true - f1.fieldType = RelForeignKey - f2.fieldType = RelForeignKey - f1.name = camelString(m1.table) - f2.name = camelString(m2.table) - f1.fullName = mi.fullName + "." + f1.name - f2.fullName = mi.fullName + "." + f2.name - f1.column = m1.table + "_id" - f2.column = m2.table + "_id" - f1.rel = true - f2.rel = true - f1.relTable = m1.table - f2.relTable = m2.table - f1.relModelInfo = m1 - f2.relModelInfo = m2 - f1.mi = mi - f2.mi = mi - - mi.fields.Add(fa) - mi.fields.Add(f1) - mi.fields.Add(f2) - mi.fields.pk = fa - - mi.uniques = []string{f1.column, f2.column} - return -} diff --git a/orm/orm.go b/orm/orm.go deleted file mode 100644 index d322881b4c..0000000000 --- a/orm/orm.go +++ /dev/null @@ -1,575 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// +build go1.8 - -// Package orm provide ORM for MySQL/PostgreSQL/sqlite -// Simple Usage -// -// package main -// -// import ( -// "fmt" -// "github.com/astaxie/beego/orm" -// _ "github.com/go-sql-driver/mysql" // import your used driver -// ) -// -// // Model Struct -// type User struct { -// Id int `orm:"auto"` -// Name string `orm:"size(100)"` -// } -// -// func init() { -// orm.RegisterDataBase("default", "mysql", "root:root@/my_db?charset=utf8", 30) -// } -// -// func main() { -// o := orm.NewOrm() -// user := User{Name: "slene"} -// // insert -// id, err := o.Insert(&user) -// // update -// user.Name = "astaxie" -// num, err := o.Update(&user) -// // read one -// u := User{Id: user.Id} -// err = o.Read(&u) -// // delete -// num, err = o.Delete(&u) -// } -// -// more docs: http://beego.me/docs/mvc/model/overview.md -package orm - -import ( - "context" - "database/sql" - "errors" - "fmt" - "os" - "reflect" - "time" -) - -// DebugQueries define the debug -const ( - DebugQueries = iota -) - -// Define common vars -var ( - Debug = false - DebugLog = NewLog(os.Stdout) - DefaultRowsLimit = -1 - DefaultRelsDepth = 2 - DefaultTimeLoc = time.Local - ErrTxHasBegan = errors.New(" transaction already begin") - ErrTxDone = errors.New(" transaction not begin") - ErrMultiRows = errors.New(" return multi rows") - ErrNoRows = errors.New(" no row found") - ErrStmtClosed = errors.New(" stmt already closed") - ErrArgs = errors.New(" args error may be empty") - ErrNotImplement = errors.New("have not implement") -) - -// Params stores the Params -type Params map[string]interface{} - -// ParamsList stores paramslist -type ParamsList []interface{} - -type orm struct { - alias *alias - db dbQuerier - isTx bool -} - -var _ Ormer = new(orm) - -// get model info and model reflect value -func (o *orm) getMiInd(md interface{}, needPtr bool) (mi *modelInfo, ind reflect.Value) { - val := reflect.ValueOf(md) - ind = reflect.Indirect(val) - typ := ind.Type() - if needPtr && val.Kind() != reflect.Ptr { - panic(fmt.Errorf(" cannot use non-ptr model struct `%s`", getFullName(typ))) - } - name := getFullName(typ) - if mi, ok := modelCache.getByFullName(name); ok { - return mi, ind - } - panic(fmt.Errorf(" table: `%s` not found, make sure it was registered with `RegisterModel()`", name)) -} - -// get field info from model info by given field name -func (o *orm) getFieldInfo(mi *modelInfo, name string) *fieldInfo { - fi, ok := mi.fields.GetByAny(name) - if !ok { - panic(fmt.Errorf(" cannot find field `%s` for model `%s`", name, mi.fullName)) - } - return fi -} - -// read data to model -func (o *orm) Read(md interface{}, cols ...string) error { - mi, ind := o.getMiInd(md, true) - return o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols, false) -} - -// read data to model, like Read(), but use "SELECT FOR UPDATE" form -func (o *orm) ReadForUpdate(md interface{}, cols ...string) error { - mi, ind := o.getMiInd(md, true) - return o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols, true) -} - -// Try to read a row from the database, or insert one if it doesn't exist -func (o *orm) ReadOrCreate(md interface{}, col1 string, cols ...string) (bool, int64, error) { - cols = append([]string{col1}, cols...) - mi, ind := o.getMiInd(md, true) - err := o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols, false) - if err == ErrNoRows { - // Create - id, err := o.Insert(md) - return (err == nil), id, err - } - - id, vid := int64(0), ind.FieldByIndex(mi.fields.pk.fieldIndex) - if mi.fields.pk.fieldType&IsPositiveIntegerField > 0 { - id = int64(vid.Uint()) - } else if mi.fields.pk.rel { - return o.ReadOrCreate(vid.Interface(), mi.fields.pk.relModelInfo.fields.pk.name) - } else { - id = vid.Int() - } - - return false, id, err -} - -// insert model data to database -func (o *orm) Insert(md interface{}) (int64, error) { - mi, ind := o.getMiInd(md, true) - id, err := o.alias.DbBaser.Insert(o.db, mi, ind, o.alias.TZ) - if err != nil { - return id, err - } - - o.setPk(mi, ind, id) - - return id, nil -} - -// set auto pk field -func (o *orm) setPk(mi *modelInfo, ind reflect.Value, id int64) { - if mi.fields.pk.auto { - if mi.fields.pk.fieldType&IsPositiveIntegerField > 0 { - ind.FieldByIndex(mi.fields.pk.fieldIndex).SetUint(uint64(id)) - } else { - ind.FieldByIndex(mi.fields.pk.fieldIndex).SetInt(id) - } - } -} - -// insert some models to database -func (o *orm) InsertMulti(bulk int, mds interface{}) (int64, error) { - var cnt int64 - - sind := reflect.Indirect(reflect.ValueOf(mds)) - - switch sind.Kind() { - case reflect.Array, reflect.Slice: - if sind.Len() == 0 { - return cnt, ErrArgs - } - default: - return cnt, ErrArgs - } - - if bulk <= 1 { - for i := 0; i < sind.Len(); i++ { - ind := reflect.Indirect(sind.Index(i)) - mi, _ := o.getMiInd(ind.Interface(), false) - id, err := o.alias.DbBaser.Insert(o.db, mi, ind, o.alias.TZ) - if err != nil { - return cnt, err - } - - o.setPk(mi, ind, id) - - cnt++ - } - } else { - mi, _ := o.getMiInd(sind.Index(0).Interface(), false) - return o.alias.DbBaser.InsertMulti(o.db, mi, sind, bulk, o.alias.TZ) - } - return cnt, nil -} - -// InsertOrUpdate data to database -func (o *orm) InsertOrUpdate(md interface{}, colConflitAndArgs ...string) (int64, error) { - mi, ind := o.getMiInd(md, true) - id, err := o.alias.DbBaser.InsertOrUpdate(o.db, mi, ind, o.alias, colConflitAndArgs...) - if err != nil { - return id, err - } - - o.setPk(mi, ind, id) - - return id, nil -} - -// update model to database. -// cols set the columns those want to update. -func (o *orm) Update(md interface{}, cols ...string) (int64, error) { - mi, ind := o.getMiInd(md, true) - return o.alias.DbBaser.Update(o.db, mi, ind, o.alias.TZ, cols) -} - -// delete model in database -// cols shows the delete conditions values read from. default is pk -func (o *orm) Delete(md interface{}, cols ...string) (int64, error) { - mi, ind := o.getMiInd(md, true) - num, err := o.alias.DbBaser.Delete(o.db, mi, ind, o.alias.TZ, cols) - if err != nil { - return num, err - } - if num > 0 { - o.setPk(mi, ind, 0) - } - return num, nil -} - -// create a models to models queryer -func (o *orm) QueryM2M(md interface{}, name string) QueryM2Mer { - mi, ind := o.getMiInd(md, true) - fi := o.getFieldInfo(mi, name) - - switch { - case fi.fieldType == RelManyToMany: - case fi.fieldType == RelReverseMany && fi.reverseFieldInfo.mi.isThrough: - default: - panic(fmt.Errorf(" model `%s` . name `%s` is not a m2m field", fi.name, mi.fullName)) - } - - return newQueryM2M(md, o, mi, fi, ind) -} - -// load related models to md model. -// args are limit, offset int and order string. -// -// example: -// orm.LoadRelated(post,"Tags") -// for _,tag := range post.Tags{...} -// -// make sure the relation is defined in model struct tags. -func (o *orm) LoadRelated(md interface{}, name string, args ...interface{}) (int64, error) { - _, fi, ind, qseter := o.queryRelated(md, name) - - qs := qseter.(*querySet) - - var relDepth int - var limit, offset int64 - var order string - for i, arg := range args { - switch i { - case 0: - if v, ok := arg.(bool); ok { - if v { - relDepth = DefaultRelsDepth - } - } else if v, ok := arg.(int); ok { - relDepth = v - } - case 1: - limit = ToInt64(arg) - case 2: - offset = ToInt64(arg) - case 3: - order, _ = arg.(string) - } - } - - switch fi.fieldType { - case RelOneToOne, RelForeignKey, RelReverseOne: - limit = 1 - offset = 0 - } - - qs.limit = limit - qs.offset = offset - qs.relDepth = relDepth - - if len(order) > 0 { - qs.orders = []string{order} - } - - find := ind.FieldByIndex(fi.fieldIndex) - - var nums int64 - var err error - switch fi.fieldType { - case RelOneToOne, RelForeignKey, RelReverseOne: - val := reflect.New(find.Type().Elem()) - container := val.Interface() - err = qs.One(container) - if err == nil { - find.Set(val) - nums = 1 - } - default: - nums, err = qs.All(find.Addr().Interface()) - } - - return nums, err -} - -// return a QuerySeter for related models to md model. -// it can do all, update, delete in QuerySeter. -// example: -// qs := orm.QueryRelated(post,"Tag") -// qs.All(&[]*Tag{}) -// -func (o *orm) QueryRelated(md interface{}, name string) QuerySeter { - // is this api needed ? - _, _, _, qs := o.queryRelated(md, name) - return qs -} - -// get QuerySeter for related models to md model -func (o *orm) queryRelated(md interface{}, name string) (*modelInfo, *fieldInfo, reflect.Value, QuerySeter) { - mi, ind := o.getMiInd(md, true) - fi := o.getFieldInfo(mi, name) - - _, _, exist := getExistPk(mi, ind) - if !exist { - panic(ErrMissPK) - } - - var qs *querySet - - switch fi.fieldType { - case RelOneToOne, RelForeignKey, RelManyToMany: - if !fi.inModel { - break - } - qs = o.getRelQs(md, mi, fi) - case RelReverseOne, RelReverseMany: - if !fi.inModel { - break - } - qs = o.getReverseQs(md, mi, fi) - } - - if qs == nil { - panic(fmt.Errorf(" name `%s` for model `%s` is not an available rel/reverse field", md, name)) - } - - return mi, fi, ind, qs -} - -// get reverse relation QuerySeter -func (o *orm) getReverseQs(md interface{}, mi *modelInfo, fi *fieldInfo) *querySet { - switch fi.fieldType { - case RelReverseOne, RelReverseMany: - default: - panic(fmt.Errorf(" name `%s` for model `%s` is not an available reverse field", fi.name, mi.fullName)) - } - - var q *querySet - - if fi.fieldType == RelReverseMany && fi.reverseFieldInfo.mi.isThrough { - q = newQuerySet(o, fi.relModelInfo).(*querySet) - q.cond = NewCondition().And(fi.reverseFieldInfoM2M.column+ExprSep+fi.reverseFieldInfo.column, md) - } else { - q = newQuerySet(o, fi.reverseFieldInfo.mi).(*querySet) - q.cond = NewCondition().And(fi.reverseFieldInfo.column, md) - } - - return q -} - -// get relation QuerySeter -func (o *orm) getRelQs(md interface{}, mi *modelInfo, fi *fieldInfo) *querySet { - switch fi.fieldType { - case RelOneToOne, RelForeignKey, RelManyToMany: - default: - panic(fmt.Errorf(" name `%s` for model `%s` is not an available rel field", fi.name, mi.fullName)) - } - - q := newQuerySet(o, fi.relModelInfo).(*querySet) - q.cond = NewCondition() - - if fi.fieldType == RelManyToMany { - q.cond = q.cond.And(fi.reverseFieldInfoM2M.column+ExprSep+fi.reverseFieldInfo.column, md) - } else { - q.cond = q.cond.And(fi.reverseFieldInfo.column, md) - } - - return q -} - -// return a QuerySeter for table operations. -// table name can be string or struct. -// e.g. QueryTable("user"), QueryTable(&user{}) or QueryTable((*User)(nil)), -func (o *orm) QueryTable(ptrStructOrTableName interface{}) (qs QuerySeter) { - var name string - if table, ok := ptrStructOrTableName.(string); ok { - name = nameStrategyMap[defaultNameStrategy](table) - if mi, ok := modelCache.get(name); ok { - qs = newQuerySet(o, mi) - } - } else { - name = getFullName(indirectType(reflect.TypeOf(ptrStructOrTableName))) - if mi, ok := modelCache.getByFullName(name); ok { - qs = newQuerySet(o, mi) - } - } - if qs == nil { - panic(fmt.Errorf(" table name: `%s` not exists", name)) - } - return -} - -// switch to another registered database driver by given name. -func (o *orm) Using(name string) error { - if o.isTx { - panic(fmt.Errorf(" transaction has been start, cannot change db")) - } - if al, ok := dataBaseCache.get(name); ok { - o.alias = al - if Debug { - o.db = newDbQueryLog(al, al.DB) - } else { - o.db = al.DB - } - } else { - return fmt.Errorf(" unknown db alias name `%s`", name) - } - return nil -} - -// begin transaction -func (o *orm) Begin() error { - return o.BeginTx(context.Background(), nil) -} - -func (o *orm) BeginTx(ctx context.Context, opts *sql.TxOptions) error { - if o.isTx { - return ErrTxHasBegan - } - var tx *sql.Tx - tx, err := o.db.(txer).BeginTx(ctx, opts) - if err != nil { - return err - } - o.isTx = true - if Debug { - o.db.(*dbQueryLog).SetDB(tx) - } else { - o.db = tx - } - return nil -} - -// commit transaction -func (o *orm) Commit() error { - if !o.isTx { - return ErrTxDone - } - err := o.db.(txEnder).Commit() - if err == nil { - o.isTx = false - o.Using(o.alias.Name) - } else if err == sql.ErrTxDone { - return ErrTxDone - } - return err -} - -// rollback transaction -func (o *orm) Rollback() error { - if !o.isTx { - return ErrTxDone - } - err := o.db.(txEnder).Rollback() - if err == nil { - o.isTx = false - o.Using(o.alias.Name) - } else if err == sql.ErrTxDone { - return ErrTxDone - } - return err -} - -// return a raw query seter for raw sql string. -func (o *orm) Raw(query string, args ...interface{}) RawSeter { - return newRawSet(o, query, args) -} - -// return current using database Driver -func (o *orm) Driver() Driver { - return driver(o.alias.Name) -} - -// return sql.DBStats for current database -func (o *orm) DBStats() *sql.DBStats { - if o.alias != nil && o.alias.DB != nil { - stats := o.alias.DB.Stats() - return &stats - } - - return nil -} - -// NewOrm create new orm -func NewOrm() Ormer { - BootStrap() // execute only once - - o := new(orm) - err := o.Using("default") - if err != nil { - panic(err) - } - return o -} - -// NewOrmWithDB create a new ormer object with specify *sql.DB for query -func NewOrmWithDB(driverName, aliasName string, db *sql.DB) (Ormer, error) { - var al *alias - - if dr, ok := drivers[driverName]; ok { - al = new(alias) - al.DbBaser = dbBasers[dr] - al.Driver = dr - } else { - return nil, fmt.Errorf("driver name `%s` have not registered", driverName) - } - - al.Name = aliasName - al.DriverName = driverName - al.DB = db - - detectTZ(al) - - o := new(orm) - o.alias = al - - if Debug { - o.db = newDbQueryLog(o.alias, db) - } else { - o.db = db - } - - return o, nil -} diff --git a/orm/orm_queryset.go b/orm/orm_queryset.go deleted file mode 100644 index 7f2fdb8fb9..0000000000 --- a/orm/orm_queryset.go +++ /dev/null @@ -1,294 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package orm - -import ( - "context" - "fmt" -) - -type colValue struct { - value int64 - opt operator -} - -type operator int - -// define Col operations -const ( - ColAdd operator = iota - ColMinus - ColMultiply - ColExcept -) - -// ColValue do the field raw changes. e.g Nums = Nums + 10. usage: -// Params{ -// "Nums": ColValue(Col_Add, 10), -// } -func ColValue(opt operator, value interface{}) interface{} { - switch opt { - case ColAdd, ColMinus, ColMultiply, ColExcept: - default: - panic(fmt.Errorf("orm.ColValue wrong operator")) - } - v, err := StrTo(ToStr(value)).Int64() - if err != nil { - panic(fmt.Errorf("orm.ColValue doesn't support non string/numeric type, %s", err)) - } - var val colValue - val.value = v - val.opt = opt - return val -} - -// real query struct -type querySet struct { - mi *modelInfo - cond *Condition - related []string - relDepth int - limit int64 - offset int64 - groups []string - orders []string - distinct bool - forupdate bool - orm *orm - ctx context.Context - forContext bool -} - -var _ QuerySeter = new(querySet) - -// add condition expression to QuerySeter. -func (o querySet) Filter(expr string, args ...interface{}) QuerySeter { - if o.cond == nil { - o.cond = NewCondition() - } - o.cond = o.cond.And(expr, args...) - return &o -} - -// add raw sql to querySeter. -func (o querySet) FilterRaw(expr string, sql string) QuerySeter { - if o.cond == nil { - o.cond = NewCondition() - } - o.cond = o.cond.Raw(expr, sql) - return &o -} - -// add NOT condition to querySeter. -func (o querySet) Exclude(expr string, args ...interface{}) QuerySeter { - if o.cond == nil { - o.cond = NewCondition() - } - o.cond = o.cond.AndNot(expr, args...) - return &o -} - -// set offset number -func (o *querySet) setOffset(num interface{}) { - o.offset = ToInt64(num) -} - -// add LIMIT value. -// args[0] means offset, e.g. LIMIT num,offset. -func (o querySet) Limit(limit interface{}, args ...interface{}) QuerySeter { - o.limit = ToInt64(limit) - if len(args) > 0 { - o.setOffset(args[0]) - } - return &o -} - -// add OFFSET value -func (o querySet) Offset(offset interface{}) QuerySeter { - o.setOffset(offset) - return &o -} - -// add GROUP expression -func (o querySet) GroupBy(exprs ...string) QuerySeter { - o.groups = exprs - return &o -} - -// add ORDER expression. -// "column" means ASC, "-column" means DESC. -func (o querySet) OrderBy(exprs ...string) QuerySeter { - o.orders = exprs - return &o -} - -// add DISTINCT to SELECT -func (o querySet) Distinct() QuerySeter { - o.distinct = true - return &o -} - -// add FOR UPDATE to SELECT -func (o querySet) ForUpdate() QuerySeter { - o.forupdate = true - return &o -} - -// set relation model to query together. -// it will query relation models and assign to parent model. -func (o querySet) RelatedSel(params ...interface{}) QuerySeter { - if len(params) == 0 { - o.relDepth = DefaultRelsDepth - } else { - for _, p := range params { - switch val := p.(type) { - case string: - o.related = append(o.related, val) - case int: - o.relDepth = val - default: - panic(fmt.Errorf(" wrong param kind: %v", val)) - } - } - } - return &o -} - -// set condition to QuerySeter. -func (o querySet) SetCond(cond *Condition) QuerySeter { - o.cond = cond - return &o -} - -// get condition from QuerySeter -func (o querySet) GetCond() *Condition { - return o.cond -} - -// return QuerySeter execution result number -func (o *querySet) Count() (int64, error) { - return o.orm.alias.DbBaser.Count(o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ) -} - -// check result empty or not after QuerySeter executed -func (o *querySet) Exist() bool { - cnt, _ := o.orm.alias.DbBaser.Count(o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ) - return cnt > 0 -} - -// execute update with parameters -func (o *querySet) Update(values Params) (int64, error) { - return o.orm.alias.DbBaser.UpdateBatch(o.orm.db, o, o.mi, o.cond, values, o.orm.alias.TZ) -} - -// execute delete -func (o *querySet) Delete() (int64, error) { - return o.orm.alias.DbBaser.DeleteBatch(o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ) -} - -// return a insert queryer. -// it can be used in times. -// example: -// i,err := sq.PrepareInsert() -// i.Add(&user1{},&user2{}) -func (o *querySet) PrepareInsert() (Inserter, error) { - return newInsertSet(o.orm, o.mi) -} - -// query all data and map to containers. -// cols means the columns when querying. -func (o *querySet) All(container interface{}, cols ...string) (int64, error) { - return o.orm.alias.DbBaser.ReadBatch(o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ, cols) -} - -// query one row data and map to containers. -// cols means the columns when querying. -func (o *querySet) One(container interface{}, cols ...string) error { - o.limit = 1 - num, err := o.orm.alias.DbBaser.ReadBatch(o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ, cols) - if err != nil { - return err - } - if num == 0 { - return ErrNoRows - } - - if num > 1 { - return ErrMultiRows - } - return nil -} - -// query all data and map to []map[string]interface. -// expres means condition expression. -// it converts data to []map[column]value. -func (o *querySet) Values(results *[]Params, exprs ...string) (int64, error) { - return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, exprs, results, o.orm.alias.TZ) -} - -// query all data and map to [][]interface -// it converts data to [][column_index]value -func (o *querySet) ValuesList(results *[]ParamsList, exprs ...string) (int64, error) { - return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, exprs, results, o.orm.alias.TZ) -} - -// query all data and map to []interface. -// it's designed for one row record set, auto change to []value, not [][column]value. -func (o *querySet) ValuesFlat(result *ParamsList, expr string) (int64, error) { - return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, []string{expr}, result, o.orm.alias.TZ) -} - -// query all rows into map[string]interface with specify key and value column name. -// keyCol = "name", valueCol = "value" -// table data -// name | value -// total | 100 -// found | 200 -// to map[string]interface{}{ -// "total": 100, -// "found": 200, -// } -func (o *querySet) RowsToMap(result *Params, keyCol, valueCol string) (int64, error) { - panic(ErrNotImplement) -} - -// query all rows into struct with specify key and value column name. -// keyCol = "name", valueCol = "value" -// table data -// name | value -// total | 100 -// found | 200 -// to struct { -// Total int -// Found int -// } -func (o *querySet) RowsToStruct(ptrStruct interface{}, keyCol, valueCol string) (int64, error) { - panic(ErrNotImplement) -} - -// set context to QuerySeter. -func (o querySet) WithContext(ctx context.Context) QuerySeter { - o.ctx = ctx - o.forContext = true - return &o -} - -// create new QuerySeter. -func newQuerySet(orm *orm, mi *modelInfo) QuerySeter { - o := new(querySet) - o.mi = mi - o.orm = orm - return o -} diff --git a/orm/qb_tidb.go b/orm/qb_tidb.go deleted file mode 100644 index 87b3ae84f8..0000000000 --- a/orm/qb_tidb.go +++ /dev/null @@ -1,182 +0,0 @@ -// Copyright 2015 TiDB Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package orm - -import ( - "fmt" - "strconv" - "strings" -) - -// TiDBQueryBuilder is the SQL build -type TiDBQueryBuilder struct { - Tokens []string -} - -// Select will join the fields -func (qb *TiDBQueryBuilder) Select(fields ...string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "SELECT", strings.Join(fields, CommaSpace)) - return qb -} - -// ForUpdate add the FOR UPDATE clause -func (qb *TiDBQueryBuilder) ForUpdate() QueryBuilder { - qb.Tokens = append(qb.Tokens, "FOR UPDATE") - return qb -} - -// From join the tables -func (qb *TiDBQueryBuilder) From(tables ...string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "FROM", strings.Join(tables, CommaSpace)) - return qb -} - -// InnerJoin INNER JOIN the table -func (qb *TiDBQueryBuilder) InnerJoin(table string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "INNER JOIN", table) - return qb -} - -// LeftJoin LEFT JOIN the table -func (qb *TiDBQueryBuilder) LeftJoin(table string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "LEFT JOIN", table) - return qb -} - -// RightJoin RIGHT JOIN the table -func (qb *TiDBQueryBuilder) RightJoin(table string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "RIGHT JOIN", table) - return qb -} - -// On join with on cond -func (qb *TiDBQueryBuilder) On(cond string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "ON", cond) - return qb -} - -// Where join the Where cond -func (qb *TiDBQueryBuilder) Where(cond string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "WHERE", cond) - return qb -} - -// And join the and cond -func (qb *TiDBQueryBuilder) And(cond string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "AND", cond) - return qb -} - -// Or join the or cond -func (qb *TiDBQueryBuilder) Or(cond string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "OR", cond) - return qb -} - -// In join the IN (vals) -func (qb *TiDBQueryBuilder) In(vals ...string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "IN", "(", strings.Join(vals, CommaSpace), ")") - return qb -} - -// OrderBy join the Order by fields -func (qb *TiDBQueryBuilder) OrderBy(fields ...string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "ORDER BY", strings.Join(fields, CommaSpace)) - return qb -} - -// Asc join the asc -func (qb *TiDBQueryBuilder) Asc() QueryBuilder { - qb.Tokens = append(qb.Tokens, "ASC") - return qb -} - -// Desc join the desc -func (qb *TiDBQueryBuilder) Desc() QueryBuilder { - qb.Tokens = append(qb.Tokens, "DESC") - return qb -} - -// Limit join the limit num -func (qb *TiDBQueryBuilder) Limit(limit int) QueryBuilder { - qb.Tokens = append(qb.Tokens, "LIMIT", strconv.Itoa(limit)) - return qb -} - -// Offset join the offset num -func (qb *TiDBQueryBuilder) Offset(offset int) QueryBuilder { - qb.Tokens = append(qb.Tokens, "OFFSET", strconv.Itoa(offset)) - return qb -} - -// GroupBy join the Group by fields -func (qb *TiDBQueryBuilder) GroupBy(fields ...string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "GROUP BY", strings.Join(fields, CommaSpace)) - return qb -} - -// Having join the Having cond -func (qb *TiDBQueryBuilder) Having(cond string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "HAVING", cond) - return qb -} - -// Update join the update table -func (qb *TiDBQueryBuilder) Update(tables ...string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "UPDATE", strings.Join(tables, CommaSpace)) - return qb -} - -// Set join the set kv -func (qb *TiDBQueryBuilder) Set(kv ...string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "SET", strings.Join(kv, CommaSpace)) - return qb -} - -// Delete join the Delete tables -func (qb *TiDBQueryBuilder) Delete(tables ...string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "DELETE") - if len(tables) != 0 { - qb.Tokens = append(qb.Tokens, strings.Join(tables, CommaSpace)) - } - return qb -} - -// InsertInto join the insert SQL -func (qb *TiDBQueryBuilder) InsertInto(table string, fields ...string) QueryBuilder { - qb.Tokens = append(qb.Tokens, "INSERT INTO", table) - if len(fields) != 0 { - fieldsStr := strings.Join(fields, CommaSpace) - qb.Tokens = append(qb.Tokens, "(", fieldsStr, ")") - } - return qb -} - -// Values join the Values(vals) -func (qb *TiDBQueryBuilder) Values(vals ...string) QueryBuilder { - valsStr := strings.Join(vals, CommaSpace) - qb.Tokens = append(qb.Tokens, "VALUES", "(", valsStr, ")") - return qb -} - -// Subquery join the sub as alias -func (qb *TiDBQueryBuilder) Subquery(sub string, alias string) string { - return fmt.Sprintf("(%s) AS %s", sub, alias) -} - -// String join all Tokens -func (qb *TiDBQueryBuilder) String() string { - return strings.Join(qb.Tokens, " ") -} diff --git a/orm/types.go b/orm/types.go deleted file mode 100644 index 6c6443c7fc..0000000000 --- a/orm/types.go +++ /dev/null @@ -1,473 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package orm - -import ( - "context" - "database/sql" - "reflect" - "time" -) - -// Driver define database driver -type Driver interface { - Name() string - Type() DriverType -} - -// Fielder define field info -type Fielder interface { - String() string - FieldType() int - SetRaw(interface{}) error - RawValue() interface{} -} - -// Ormer define the orm interface -type Ormer interface { - // read data to model - // for example: - // this will find User by Id field - // u = &User{Id: user.Id} - // err = Ormer.Read(u) - // this will find User by UserName field - // u = &User{UserName: "astaxie", Password: "pass"} - // err = Ormer.Read(u, "UserName") - Read(md interface{}, cols ...string) error - // Like Read(), but with "FOR UPDATE" clause, useful in transaction. - // Some databases are not support this feature. - ReadForUpdate(md interface{}, cols ...string) error - // Try to read a row from the database, or insert one if it doesn't exist - ReadOrCreate(md interface{}, col1 string, cols ...string) (bool, int64, error) - // insert model data to database - // for example: - // user := new(User) - // id, err = Ormer.Insert(user) - // user must a pointer and Insert will set user's pk field - Insert(interface{}) (int64, error) - // mysql:InsertOrUpdate(model) or InsertOrUpdate(model,"colu=colu+value") - // if colu type is integer : can use(+-*/), string : convert(colu,"value") - // postgres: InsertOrUpdate(model,"conflictColumnName") or InsertOrUpdate(model,"conflictColumnName","colu=colu+value") - // if colu type is integer : can use(+-*/), string : colu || "value" - InsertOrUpdate(md interface{}, colConflitAndArgs ...string) (int64, error) - // insert some models to database - InsertMulti(bulk int, mds interface{}) (int64, error) - // update model to database. - // cols set the columns those want to update. - // find model by Id(pk) field and update columns specified by fields, if cols is null then update all columns - // for example: - // user := User{Id: 2} - // user.Langs = append(user.Langs, "zh-CN", "en-US") - // user.Extra.Name = "beego" - // user.Extra.Data = "orm" - // num, err = Ormer.Update(&user, "Langs", "Extra") - Update(md interface{}, cols ...string) (int64, error) - // delete model in database - Delete(md interface{}, cols ...string) (int64, error) - // load related models to md model. - // args are limit, offset int and order string. - // - // example: - // Ormer.LoadRelated(post,"Tags") - // for _,tag := range post.Tags{...} - //args[0] bool true useDefaultRelsDepth ; false depth 0 - //args[0] int loadRelationDepth - //args[1] int limit default limit 1000 - //args[2] int offset default offset 0 - //args[3] string order for example : "-Id" - // make sure the relation is defined in model struct tags. - LoadRelated(md interface{}, name string, args ...interface{}) (int64, error) - // create a models to models queryer - // for example: - // post := Post{Id: 4} - // m2m := Ormer.QueryM2M(&post, "Tags") - QueryM2M(md interface{}, name string) QueryM2Mer - // return a QuerySeter for table operations. - // table name can be string or struct. - // e.g. QueryTable("user"), QueryTable(&user{}) or QueryTable((*User)(nil)), - QueryTable(ptrStructOrTableName interface{}) QuerySeter - // switch to another registered database driver by given name. - Using(name string) error - // begin transaction - // for example: - // o := NewOrm() - // err := o.Begin() - // ... - // err = o.Rollback() - Begin() error - // begin transaction with provided context and option - // the provided context is used until the transaction is committed or rolled back. - // if the context is canceled, the transaction will be rolled back. - // the provided TxOptions is optional and may be nil if defaults should be used. - // if a non-default isolation level is used that the driver doesn't support, an error will be returned. - // for example: - // o := NewOrm() - // err := o.BeginTx(context.Background(), &sql.TxOptions{Isolation: sql.LevelRepeatableRead}) - // ... - // err = o.Rollback() - BeginTx(ctx context.Context, opts *sql.TxOptions) error - // commit transaction - Commit() error - // rollback transaction - Rollback() error - // return a raw query seter for raw sql string. - // for example: - // ormer.Raw("UPDATE `user` SET `user_name` = ? WHERE `user_name` = ?", "slene", "testing").Exec() - // // update user testing's name to slene - Raw(query string, args ...interface{}) RawSeter - Driver() Driver - DBStats() *sql.DBStats -} - -// Inserter insert prepared statement -type Inserter interface { - Insert(interface{}) (int64, error) - Close() error -} - -// QuerySeter query seter -type QuerySeter interface { - // add condition expression to QuerySeter. - // for example: - // filter by UserName == 'slene' - // qs.Filter("UserName", "slene") - // sql : left outer join profile on t0.id1==t1.id2 where t1.age == 28 - // Filter("profile__Age", 28) - // // time compare - // qs.Filter("created", time.Now()) - Filter(string, ...interface{}) QuerySeter - // add raw sql to querySeter. - // for example: - // qs.FilterRaw("user_id IN (SELECT id FROM profile WHERE age>=18)") - // //sql-> WHERE user_id IN (SELECT id FROM profile WHERE age>=18) - FilterRaw(string, string) QuerySeter - // add NOT condition to querySeter. - // have the same usage as Filter - Exclude(string, ...interface{}) QuerySeter - // set condition to QuerySeter. - // sql's where condition - // cond := orm.NewCondition() - // cond1 := cond.And("profile__isnull", false).AndNot("status__in", 1).Or("profile__age__gt", 2000) - // //sql-> WHERE T0.`profile_id` IS NOT NULL AND NOT T0.`Status` IN (?) OR T1.`age` > 2000 - // num, err := qs.SetCond(cond1).Count() - SetCond(*Condition) QuerySeter - // get condition from QuerySeter. - // sql's where condition - // cond := orm.NewCondition() - // cond = cond.And("profile__isnull", false).AndNot("status__in", 1) - // qs = qs.SetCond(cond) - // cond = qs.GetCond() - // cond := cond.Or("profile__age__gt", 2000) - // //sql-> WHERE T0.`profile_id` IS NOT NULL AND NOT T0.`Status` IN (?) OR T1.`age` > 2000 - // num, err := qs.SetCond(cond).Count() - GetCond() *Condition - // add LIMIT value. - // args[0] means offset, e.g. LIMIT num,offset. - // if Limit <= 0 then Limit will be set to default limit ,eg 1000 - // if QuerySeter doesn't call Limit, the sql's Limit will be set to default limit, eg 1000 - // for example: - // qs.Limit(10, 2) - // // sql-> limit 10 offset 2 - Limit(limit interface{}, args ...interface{}) QuerySeter - // add OFFSET value - // same as Limit function's args[0] - Offset(offset interface{}) QuerySeter - // add GROUP BY expression - // for example: - // qs.GroupBy("id") - GroupBy(exprs ...string) QuerySeter - // add ORDER expression. - // "column" means ASC, "-column" means DESC. - // for example: - // qs.OrderBy("-status") - OrderBy(exprs ...string) QuerySeter - // set relation model to query together. - // it will query relation models and assign to parent model. - // for example: - // // will load all related fields use left join . - // qs.RelatedSel().One(&user) - // // will load related field only profile - // qs.RelatedSel("profile").One(&user) - // user.Profile.Age = 32 - RelatedSel(params ...interface{}) QuerySeter - // Set Distinct - // for example: - // o.QueryTable("policy").Filter("Groups__Group__Users__User", user). - // Distinct(). - // All(&permissions) - Distinct() QuerySeter - // set FOR UPDATE to query. - // for example: - // o.QueryTable("user").Filter("uid", uid).ForUpdate().All(&users) - ForUpdate() QuerySeter - // return QuerySeter execution result number - // for example: - // num, err = qs.Filter("profile__age__gt", 28).Count() - Count() (int64, error) - // check result empty or not after QuerySeter executed - // the same as QuerySeter.Count > 0 - Exist() bool - // execute update with parameters - // for example: - // num, err = qs.Filter("user_name", "slene").Update(Params{ - // "Nums": ColValue(Col_Minus, 50), - // }) // user slene's Nums will minus 50 - // num, err = qs.Filter("UserName", "slene").Update(Params{ - // "user_name": "slene2" - // }) // user slene's name will change to slene2 - Update(values Params) (int64, error) - // delete from table - //for example: - // num ,err = qs.Filter("user_name__in", "testing1", "testing2").Delete() - // //delete two user who's name is testing1 or testing2 - Delete() (int64, error) - // return a insert queryer. - // it can be used in times. - // example: - // i,err := sq.PrepareInsert() - // num, err = i.Insert(&user1) // user table will add one record user1 at once - // num, err = i.Insert(&user2) // user table will add one record user2 at once - // err = i.Close() //don't forget call Close - PrepareInsert() (Inserter, error) - // query all data and map to containers. - // cols means the columns when querying. - // for example: - // var users []*User - // qs.All(&users) // users[0],users[1],users[2] ... - All(container interface{}, cols ...string) (int64, error) - // query one row data and map to containers. - // cols means the columns when querying. - // for example: - // var user User - // qs.One(&user) //user.UserName == "slene" - One(container interface{}, cols ...string) error - // query all data and map to []map[string]interface. - // expres means condition expression. - // it converts data to []map[column]value. - // for example: - // var maps []Params - // qs.Values(&maps) //maps[0]["UserName"]=="slene" - Values(results *[]Params, exprs ...string) (int64, error) - // query all data and map to [][]interface - // it converts data to [][column_index]value - // for example: - // var list []ParamsList - // qs.ValuesList(&list) // list[0][1] == "slene" - ValuesList(results *[]ParamsList, exprs ...string) (int64, error) - // query all data and map to []interface. - // it's designed for one column record set, auto change to []value, not [][column]value. - // for example: - // var list ParamsList - // qs.ValuesFlat(&list, "UserName") // list[0] == "slene" - ValuesFlat(result *ParamsList, expr string) (int64, error) - // query all rows into map[string]interface with specify key and value column name. - // keyCol = "name", valueCol = "value" - // table data - // name | value - // total | 100 - // found | 200 - // to map[string]interface{}{ - // "total": 100, - // "found": 200, - // } - RowsToMap(result *Params, keyCol, valueCol string) (int64, error) - // query all rows into struct with specify key and value column name. - // keyCol = "name", valueCol = "value" - // table data - // name | value - // total | 100 - // found | 200 - // to struct { - // Total int - // Found int - // } - RowsToStruct(ptrStruct interface{}, keyCol, valueCol string) (int64, error) -} - -// QueryM2Mer model to model query struct -// all operations are on the m2m table only, will not affect the origin model table -type QueryM2Mer interface { - // add models to origin models when creating queryM2M. - // example: - // m2m := orm.QueryM2M(post,"Tag") - // m2m.Add(&Tag1{},&Tag2{}) - // for _,tag := range post.Tags{}{ ... } - // param could also be any of the follow - // []*Tag{{Id:3,Name: "TestTag1"}, {Id:4,Name: "TestTag2"}} - // &Tag{Id:5,Name: "TestTag3"} - // []interface{}{&Tag{Id:6,Name: "TestTag4"}} - // insert one or more rows to m2m table - // make sure the relation is defined in post model struct tag. - Add(...interface{}) (int64, error) - // remove models following the origin model relationship - // only delete rows from m2m table - // for example: - //tag3 := &Tag{Id:5,Name: "TestTag3"} - //num, err = m2m.Remove(tag3) - Remove(...interface{}) (int64, error) - // check model is existed in relationship of origin model - Exist(interface{}) bool - // clean all models in related of origin model - Clear() (int64, error) - // count all related models of origin model - Count() (int64, error) -} - -// RawPreparer raw query statement -type RawPreparer interface { - Exec(...interface{}) (sql.Result, error) - Close() error -} - -// RawSeter raw query seter -// create From Ormer.Raw -// for example: -// sql := fmt.Sprintf("SELECT %sid%s,%sname%s FROM %suser%s WHERE id = ?",Q,Q,Q,Q,Q,Q) -// rs := Ormer.Raw(sql, 1) -type RawSeter interface { - //execute sql and get result - Exec() (sql.Result, error) - //query data and map to container - //for example: - // var name string - // var id int - // rs.QueryRow(&id,&name) // id==2 name=="slene" - QueryRow(containers ...interface{}) error - - // query data rows and map to container - // var ids []int - // var names []int - // query = fmt.Sprintf("SELECT 'id','name' FROM %suser%s", Q, Q) - // num, err = dORM.Raw(query).QueryRows(&ids,&names) // ids=>{1,2},names=>{"nobody","slene"} - QueryRows(containers ...interface{}) (int64, error) - SetArgs(...interface{}) RawSeter - // query data to []map[string]interface - // see QuerySeter's Values - Values(container *[]Params, cols ...string) (int64, error) - // query data to [][]interface - // see QuerySeter's ValuesList - ValuesList(container *[]ParamsList, cols ...string) (int64, error) - // query data to []interface - // see QuerySeter's ValuesFlat - ValuesFlat(container *ParamsList, cols ...string) (int64, error) - // query all rows into map[string]interface with specify key and value column name. - // keyCol = "name", valueCol = "value" - // table data - // name | value - // total | 100 - // found | 200 - // to map[string]interface{}{ - // "total": 100, - // "found": 200, - // } - RowsToMap(result *Params, keyCol, valueCol string) (int64, error) - // query all rows into struct with specify key and value column name. - // keyCol = "name", valueCol = "value" - // table data - // name | value - // total | 100 - // found | 200 - // to struct { - // Total int - // Found int - // } - RowsToStruct(ptrStruct interface{}, keyCol, valueCol string) (int64, error) - - // return prepared raw statement for used in times. - // for example: - // pre, err := dORM.Raw("INSERT INTO tag (name) VALUES (?)").Prepare() - // r, err := pre.Exec("name1") // INSERT INTO tag (name) VALUES (`name1`) - Prepare() (RawPreparer, error) -} - -// stmtQuerier statement querier -type stmtQuerier interface { - Close() error - Exec(args ...interface{}) (sql.Result, error) - //ExecContext(ctx context.Context, args ...interface{}) (sql.Result, error) - Query(args ...interface{}) (*sql.Rows, error) - //QueryContext(args ...interface{}) (*sql.Rows, error) - QueryRow(args ...interface{}) *sql.Row - //QueryRowContext(ctx context.Context, args ...interface{}) *sql.Row -} - -// db querier -type dbQuerier interface { - Prepare(query string) (*sql.Stmt, error) - PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) - Exec(query string, args ...interface{}) (sql.Result, error) - ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) - Query(query string, args ...interface{}) (*sql.Rows, error) - QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) - QueryRow(query string, args ...interface{}) *sql.Row - QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row -} - -// type DB interface { -// Begin() (*sql.Tx, error) -// Prepare(query string) (stmtQuerier, error) -// Exec(query string, args ...interface{}) (sql.Result, error) -// Query(query string, args ...interface{}) (*sql.Rows, error) -// QueryRow(query string, args ...interface{}) *sql.Row -// } - -// transaction beginner -type txer interface { - Begin() (*sql.Tx, error) - BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) -} - -// transaction ending -type txEnder interface { - Commit() error - Rollback() error -} - -// base database struct -type dbBaser interface { - Read(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string, bool) error - Insert(dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error) - InsertOrUpdate(dbQuerier, *modelInfo, reflect.Value, *alias, ...string) (int64, error) - InsertMulti(dbQuerier, *modelInfo, reflect.Value, int, *time.Location) (int64, error) - InsertValue(dbQuerier, *modelInfo, bool, []string, []interface{}) (int64, error) - InsertStmt(stmtQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error) - Update(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) (int64, error) - Delete(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) (int64, error) - ReadBatch(dbQuerier, *querySet, *modelInfo, *Condition, interface{}, *time.Location, []string) (int64, error) - SupportUpdateJoin() bool - UpdateBatch(dbQuerier, *querySet, *modelInfo, *Condition, Params, *time.Location) (int64, error) - DeleteBatch(dbQuerier, *querySet, *modelInfo, *Condition, *time.Location) (int64, error) - Count(dbQuerier, *querySet, *modelInfo, *Condition, *time.Location) (int64, error) - OperatorSQL(string) string - GenerateOperatorSQL(*modelInfo, *fieldInfo, string, []interface{}, *time.Location) (string, []interface{}) - GenerateOperatorLeftCol(*fieldInfo, string, *string) - PrepareInsert(dbQuerier, *modelInfo) (stmtQuerier, string, error) - ReadValues(dbQuerier, *querySet, *modelInfo, *Condition, []string, interface{}, *time.Location) (int64, error) - RowsTo(dbQuerier, *querySet, *modelInfo, *Condition, interface{}, string, string, *time.Location) (int64, error) - MaxLimit() uint64 - TableQuote() string - ReplaceMarks(*string) - HasReturningID(*modelInfo, *string) bool - TimeFromDB(*time.Time, *time.Location) - TimeToDB(*time.Time, *time.Location) - DbTypes() map[string]string - GetTables(dbQuerier) (map[string]bool, error) - GetColumns(dbQuerier, string) (map[string][3]string, error) - ShowTablesQuery() string - ShowColumnsQuery(string) string - IndexExists(dbQuerier, string, string) bool - collectFieldValue(*modelInfo, *fieldInfo, reflect.Value, bool, *time.Location) (interface{}, error) - setval(dbQuerier, *modelInfo, []string) error -} diff --git a/parser.go b/parser.go deleted file mode 100644 index 81990a4a06..0000000000 --- a/parser.go +++ /dev/null @@ -1,588 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package beego - -import ( - "encoding/json" - "errors" - "fmt" - "go/ast" - "go/parser" - "go/token" - "io/ioutil" - "os" - "path/filepath" - "regexp" - "sort" - "strconv" - "strings" - "unicode" - - "github.com/astaxie/beego/context/param" - "github.com/astaxie/beego/logs" - "github.com/astaxie/beego/utils" -) - -var globalRouterTemplate = `package routers - -import ( - "github.com/astaxie/beego" - "github.com/astaxie/beego/context/param"{{.globalimport}} -) - -func init() { -{{.globalinfo}} -} -` - -var ( - lastupdateFilename = "lastupdate.tmp" - commentFilename string - pkgLastupdate map[string]int64 - genInfoList map[string][]ControllerComments - - routerHooks = map[string]int{ - "beego.BeforeStatic": BeforeStatic, - "beego.BeforeRouter": BeforeRouter, - "beego.BeforeExec": BeforeExec, - "beego.AfterExec": AfterExec, - "beego.FinishRouter": FinishRouter, - } - - routerHooksMapping = map[int]string{ - BeforeStatic: "beego.BeforeStatic", - BeforeRouter: "beego.BeforeRouter", - BeforeExec: "beego.BeforeExec", - AfterExec: "beego.AfterExec", - FinishRouter: "beego.FinishRouter", - } -) - -const commentPrefix = "commentsRouter_" - -func init() { - pkgLastupdate = make(map[string]int64) -} - -func parserPkg(pkgRealpath, pkgpath string) error { - rep := strings.NewReplacer("\\", "_", "/", "_", ".", "_") - commentFilename, _ = filepath.Rel(AppPath, pkgRealpath) - commentFilename = commentPrefix + rep.Replace(commentFilename) + ".go" - if !compareFile(pkgRealpath) { - logs.Info(pkgRealpath + " no changed") - return nil - } - genInfoList = make(map[string][]ControllerComments) - fileSet := token.NewFileSet() - astPkgs, err := parser.ParseDir(fileSet, pkgRealpath, func(info os.FileInfo) bool { - name := info.Name() - return !info.IsDir() && !strings.HasPrefix(name, ".") && strings.HasSuffix(name, ".go") - }, parser.ParseComments) - - if err != nil { - return err - } - for _, pkg := range astPkgs { - for _, fl := range pkg.Files { - for _, d := range fl.Decls { - switch specDecl := d.(type) { - case *ast.FuncDecl: - if specDecl.Recv != nil { - exp, ok := specDecl.Recv.List[0].Type.(*ast.StarExpr) // Check that the type is correct first beforing throwing to parser - if ok { - parserComments(specDecl, fmt.Sprint(exp.X), pkgpath) - } - } - } - } - } - } - genRouterCode(pkgRealpath) - savetoFile(pkgRealpath) - return nil -} - -type parsedComment struct { - routerPath string - methods []string - params map[string]parsedParam - filters []parsedFilter - imports []parsedImport -} - -type parsedImport struct { - importPath string - importAlias string -} - -type parsedFilter struct { - pattern string - pos int - filter string - params []bool -} - -type parsedParam struct { - name string - datatype string - location string - defValue string - required bool -} - -func parserComments(f *ast.FuncDecl, controllerName, pkgpath string) error { - if f.Doc != nil { - parsedComments, err := parseComment(f.Doc.List) - if err != nil { - return err - } - for _, parsedComment := range parsedComments { - if parsedComment.routerPath != "" { - key := pkgpath + ":" + controllerName - cc := ControllerComments{} - cc.Method = f.Name.String() - cc.Router = parsedComment.routerPath - cc.AllowHTTPMethods = parsedComment.methods - cc.MethodParams = buildMethodParams(f.Type.Params.List, parsedComment) - cc.FilterComments = buildFilters(parsedComment.filters) - cc.ImportComments = buildImports(parsedComment.imports) - genInfoList[key] = append(genInfoList[key], cc) - } - } - } - return nil -} - -func buildImports(pis []parsedImport) []*ControllerImportComments { - var importComments []*ControllerImportComments - - for _, pi := range pis { - importComments = append(importComments, &ControllerImportComments{ - ImportPath: pi.importPath, - ImportAlias: pi.importAlias, - }) - } - - return importComments -} - -func buildFilters(pfs []parsedFilter) []*ControllerFilterComments { - var filterComments []*ControllerFilterComments - - for _, pf := range pfs { - var ( - returnOnOutput bool - resetParams bool - ) - - if len(pf.params) >= 1 { - returnOnOutput = pf.params[0] - } - - if len(pf.params) >= 2 { - resetParams = pf.params[1] - } - - filterComments = append(filterComments, &ControllerFilterComments{ - Filter: pf.filter, - Pattern: pf.pattern, - Pos: pf.pos, - ReturnOnOutput: returnOnOutput, - ResetParams: resetParams, - }) - } - - return filterComments -} - -func buildMethodParams(funcParams []*ast.Field, pc *parsedComment) []*param.MethodParam { - result := make([]*param.MethodParam, 0, len(funcParams)) - for _, fparam := range funcParams { - for _, pName := range fparam.Names { - methodParam := buildMethodParam(fparam, pName.Name, pc) - result = append(result, methodParam) - } - } - return result -} - -func buildMethodParam(fparam *ast.Field, name string, pc *parsedComment) *param.MethodParam { - options := []param.MethodParamOption{} - if cparam, ok := pc.params[name]; ok { - //Build param from comment info - name = cparam.name - if cparam.required { - options = append(options, param.IsRequired) - } - switch cparam.location { - case "body": - options = append(options, param.InBody) - case "header": - options = append(options, param.InHeader) - case "path": - options = append(options, param.InPath) - } - if cparam.defValue != "" { - options = append(options, param.Default(cparam.defValue)) - } - } else { - if paramInPath(name, pc.routerPath) { - options = append(options, param.InPath) - } - } - return param.New(name, options...) -} - -func paramInPath(name, route string) bool { - return strings.HasSuffix(route, ":"+name) || - strings.Contains(route, ":"+name+"/") -} - -var routeRegex = regexp.MustCompile(`@router\s+(\S+)(?:\s+\[(\S+)\])?`) - -func parseComment(lines []*ast.Comment) (pcs []*parsedComment, err error) { - pcs = []*parsedComment{} - params := map[string]parsedParam{} - filters := []parsedFilter{} - imports := []parsedImport{} - - for _, c := range lines { - t := strings.TrimSpace(strings.TrimLeft(c.Text, "//")) - if strings.HasPrefix(t, "@Param") { - pv := getparams(strings.TrimSpace(strings.TrimLeft(t, "@Param"))) - if len(pv) < 4 { - logs.Error("Invalid @Param format. Needs at least 4 parameters") - } - p := parsedParam{} - names := strings.SplitN(pv[0], "=>", 2) - p.name = names[0] - funcParamName := p.name - if len(names) > 1 { - funcParamName = names[1] - } - p.location = pv[1] - p.datatype = pv[2] - switch len(pv) { - case 5: - p.required, _ = strconv.ParseBool(pv[3]) - case 6: - p.defValue = pv[3] - p.required, _ = strconv.ParseBool(pv[4]) - } - params[funcParamName] = p - } - } - - for _, c := range lines { - t := strings.TrimSpace(strings.TrimLeft(c.Text, "//")) - if strings.HasPrefix(t, "@Import") { - iv := getparams(strings.TrimSpace(strings.TrimLeft(t, "@Import"))) - if len(iv) == 0 || len(iv) > 2 { - logs.Error("Invalid @Import format. Only accepts 1 or 2 parameters") - continue - } - - p := parsedImport{} - p.importPath = iv[0] - - if len(iv) == 2 { - p.importAlias = iv[1] - } - - imports = append(imports, p) - } - } - -filterLoop: - for _, c := range lines { - t := strings.TrimSpace(strings.TrimLeft(c.Text, "//")) - if strings.HasPrefix(t, "@Filter") { - fv := getparams(strings.TrimSpace(strings.TrimLeft(t, "@Filter"))) - if len(fv) < 3 { - logs.Error("Invalid @Filter format. Needs at least 3 parameters") - continue filterLoop - } - - p := parsedFilter{} - p.pattern = fv[0] - posName := fv[1] - if pos, exists := routerHooks[posName]; exists { - p.pos = pos - } else { - logs.Error("Invalid @Filter pos: ", posName) - continue filterLoop - } - - p.filter = fv[2] - fvParams := fv[3:] - for _, fvParam := range fvParams { - switch fvParam { - case "true": - p.params = append(p.params, true) - case "false": - p.params = append(p.params, false) - default: - logs.Error("Invalid @Filter param: ", fvParam) - continue filterLoop - } - } - - filters = append(filters, p) - } - } - - for _, c := range lines { - var pc = &parsedComment{} - pc.params = params - pc.filters = filters - pc.imports = imports - - t := strings.TrimSpace(strings.TrimLeft(c.Text, "//")) - if strings.HasPrefix(t, "@router") { - t := strings.TrimSpace(strings.TrimLeft(c.Text, "//")) - matches := routeRegex.FindStringSubmatch(t) - if len(matches) == 3 { - pc.routerPath = matches[1] - methods := matches[2] - if methods == "" { - pc.methods = []string{"get"} - //pc.hasGet = true - } else { - pc.methods = strings.Split(methods, ",") - //pc.hasGet = strings.Contains(methods, "get") - } - pcs = append(pcs, pc) - } else { - return nil, errors.New("Router information is missing") - } - } - } - return -} - -// direct copy from bee\g_docs.go -// analysis params return []string -// @Param query form string true "The email for login" -// [query form string true "The email for login"] -func getparams(str string) []string { - var s []rune - var j int - var start bool - var r []string - var quoted int8 - for _, c := range str { - if unicode.IsSpace(c) && quoted == 0 { - if !start { - continue - } else { - start = false - j++ - r = append(r, string(s)) - s = make([]rune, 0) - continue - } - } - - start = true - if c == '"' { - quoted ^= 1 - continue - } - s = append(s, c) - } - if len(s) > 0 { - r = append(r, string(s)) - } - return r -} - -func genRouterCode(pkgRealpath string) { - os.Mkdir(getRouterDir(pkgRealpath), 0755) - logs.Info("generate router from comments") - var ( - globalinfo string - globalimport string - sortKey []string - ) - for k := range genInfoList { - sortKey = append(sortKey, k) - } - sort.Strings(sortKey) - for _, k := range sortKey { - cList := genInfoList[k] - sort.Sort(ControllerCommentsSlice(cList)) - for _, c := range cList { - allmethod := "nil" - if len(c.AllowHTTPMethods) > 0 { - allmethod = "[]string{" - for _, m := range c.AllowHTTPMethods { - allmethod += `"` + m + `",` - } - allmethod = strings.TrimRight(allmethod, ",") + "}" - } - - params := "nil" - if len(c.Params) > 0 { - params = "[]map[string]string{" - for _, p := range c.Params { - for k, v := range p { - params = params + `map[string]string{` + k + `:"` + v + `"},` - } - } - params = strings.TrimRight(params, ",") + "}" - } - - methodParams := "param.Make(" - if len(c.MethodParams) > 0 { - lines := make([]string, 0, len(c.MethodParams)) - for _, m := range c.MethodParams { - lines = append(lines, fmt.Sprint(m)) - } - methodParams += "\n " + - strings.Join(lines, ",\n ") + - ",\n " - } - methodParams += ")" - - imports := "" - if len(c.ImportComments) > 0 { - for _, i := range c.ImportComments { - var s string - if i.ImportAlias != "" { - s = fmt.Sprintf(` - %s "%s"`, i.ImportAlias, i.ImportPath) - } else { - s = fmt.Sprintf(` - "%s"`, i.ImportPath) - } - if !strings.Contains(globalimport, s) { - imports += s - } - } - } - - filters := "" - if len(c.FilterComments) > 0 { - for _, f := range c.FilterComments { - filters += fmt.Sprintf(` &beego.ControllerFilter{ - Pattern: "%s", - Pos: %s, - Filter: %s, - ReturnOnOutput: %v, - ResetParams: %v, - },`, f.Pattern, routerHooksMapping[f.Pos], f.Filter, f.ReturnOnOutput, f.ResetParams) - } - } - - if filters == "" { - filters = "nil" - } else { - filters = fmt.Sprintf(`[]*beego.ControllerFilter{ -%s - }`, filters) - } - - globalimport += imports - - globalinfo = globalinfo + ` - beego.GlobalControllerRouter["` + k + `"] = append(beego.GlobalControllerRouter["` + k + `"], - beego.ControllerComments{ - Method: "` + strings.TrimSpace(c.Method) + `", - ` + "Router: `" + c.Router + "`" + `, - AllowHTTPMethods: ` + allmethod + `, - MethodParams: ` + methodParams + `, - Filters: ` + filters + `, - Params: ` + params + `}) -` - } - } - - if globalinfo != "" { - f, err := os.Create(filepath.Join(getRouterDir(pkgRealpath), commentFilename)) - if err != nil { - panic(err) - } - defer f.Close() - - content := strings.Replace(globalRouterTemplate, "{{.globalinfo}}", globalinfo, -1) - content = strings.Replace(content, "{{.globalimport}}", globalimport, -1) - f.WriteString(content) - } -} - -func compareFile(pkgRealpath string) bool { - if !utils.FileExists(filepath.Join(getRouterDir(pkgRealpath), commentFilename)) { - return true - } - if utils.FileExists(lastupdateFilename) { - content, err := ioutil.ReadFile(lastupdateFilename) - if err != nil { - return true - } - json.Unmarshal(content, &pkgLastupdate) - lastupdate, err := getpathTime(pkgRealpath) - if err != nil { - return true - } - if v, ok := pkgLastupdate[pkgRealpath]; ok { - if lastupdate <= v { - return false - } - } - } - return true -} - -func savetoFile(pkgRealpath string) { - lastupdate, err := getpathTime(pkgRealpath) - if err != nil { - return - } - pkgLastupdate[pkgRealpath] = lastupdate - d, err := json.Marshal(pkgLastupdate) - if err != nil { - return - } - ioutil.WriteFile(lastupdateFilename, d, os.ModePerm) -} - -func getpathTime(pkgRealpath string) (lastupdate int64, err error) { - fl, err := ioutil.ReadDir(pkgRealpath) - if err != nil { - return lastupdate, err - } - for _, f := range fl { - if lastupdate < f.ModTime().UnixNano() { - lastupdate = f.ModTime().UnixNano() - } - } - return lastupdate, nil -} - -func getRouterDir(pkgRealpath string) string { - dir := filepath.Dir(pkgRealpath) - for { - d := filepath.Join(dir, "routers") - if utils.FileExists(d) { - return d - } - - if r, _ := filepath.Rel(dir, AppPath); r == "." { - return d - } - // Parent dir. - dir = filepath.Dir(dir) - } -} diff --git a/router.go b/router.go deleted file mode 100644 index c7abe38c59..0000000000 --- a/router.go +++ /dev/null @@ -1,1007 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package beego - -import ( - "errors" - "fmt" - "net/http" - "path" - "path/filepath" - "reflect" - "strconv" - "strings" - "sync" - "time" - - beecontext "github.com/astaxie/beego/context" - "github.com/astaxie/beego/context/param" - "github.com/astaxie/beego/logs" - "github.com/astaxie/beego/toolbox" - "github.com/astaxie/beego/utils" -) - -// default filter execution points -const ( - BeforeStatic = iota - BeforeRouter - BeforeExec - AfterExec - FinishRouter -) - -const ( - routerTypeBeego = iota - routerTypeRESTFul - routerTypeHandler -) - -var ( - // HTTPMETHOD list the supported http methods. - HTTPMETHOD = map[string]bool{ - "GET": true, - "POST": true, - "PUT": true, - "DELETE": true, - "PATCH": true, - "OPTIONS": true, - "HEAD": true, - "TRACE": true, - "CONNECT": true, - "MKCOL": true, - "COPY": true, - "MOVE": true, - "PROPFIND": true, - "PROPPATCH": true, - "LOCK": true, - "UNLOCK": true, - } - // these beego.Controller's methods shouldn't reflect to AutoRouter - exceptMethod = []string{"Init", "Prepare", "Finish", "Render", "RenderString", - "RenderBytes", "Redirect", "Abort", "StopRun", "UrlFor", "ServeJSON", "ServeJSONP", - "ServeYAML", "ServeXML", "Input", "ParseForm", "GetString", "GetStrings", "GetInt", "GetBool", - "GetFloat", "GetFile", "SaveToFile", "StartSession", "SetSession", "GetSession", - "DelSession", "SessionRegenerateID", "DestroySession", "IsAjax", "GetSecureCookie", - "SetSecureCookie", "XsrfToken", "CheckXsrfCookie", "XsrfFormHtml", - "GetControllerAndAction", "ServeFormatted"} - - urlPlaceholder = "{{placeholder}}" - // DefaultAccessLogFilter will skip the accesslog if return true - DefaultAccessLogFilter FilterHandler = &logFilter{} -) - -// FilterHandler is an interface for -type FilterHandler interface { - Filter(*beecontext.Context) bool -} - -// default log filter static file will not show -type logFilter struct { -} - -func (l *logFilter) Filter(ctx *beecontext.Context) bool { - requestPath := path.Clean(ctx.Request.URL.Path) - if requestPath == "/favicon.ico" || requestPath == "/robots.txt" { - return true - } - for prefix := range BConfig.WebConfig.StaticDir { - if strings.HasPrefix(requestPath, prefix) { - return true - } - } - return false -} - -// ExceptMethodAppend to append a slice's value into "exceptMethod", for controller's methods shouldn't reflect to AutoRouter -func ExceptMethodAppend(action string) { - exceptMethod = append(exceptMethod, action) -} - -// ControllerInfo holds information about the controller. -type ControllerInfo struct { - pattern string - controllerType reflect.Type - methods map[string]string - handler http.Handler - runFunction FilterFunc - routerType int - initialize func() ControllerInterface - methodParams []*param.MethodParam -} - -// ControllerRegister containers registered router rules, controller handlers and filters. -type ControllerRegister struct { - routers map[string]*Tree - enablePolicy bool - policies map[string]*Tree - enableFilter bool - filters [FinishRouter + 1][]*FilterRouter - pool sync.Pool -} - -// NewControllerRegister returns a new ControllerRegister. -func NewControllerRegister() *ControllerRegister { - return &ControllerRegister{ - routers: make(map[string]*Tree), - policies: make(map[string]*Tree), - pool: sync.Pool{ - New: func() interface{} { - return beecontext.NewContext() - }, - }, - } -} - -// Add controller handler and pattern rules to ControllerRegister. -// usage: -// default methods is the same name as method -// Add("/user",&UserController{}) -// Add("/api/list",&RestController{},"*:ListFood") -// Add("/api/create",&RestController{},"post:CreateFood") -// Add("/api/update",&RestController{},"put:UpdateFood") -// Add("/api/delete",&RestController{},"delete:DeleteFood") -// Add("/api",&RestController{},"get,post:ApiFunc" -// Add("/simple",&SimpleController{},"get:GetFunc;post:PostFunc") -func (p *ControllerRegister) Add(pattern string, c ControllerInterface, mappingMethods ...string) { - p.addWithMethodParams(pattern, c, nil, mappingMethods...) -} - -func (p *ControllerRegister) addWithMethodParams(pattern string, c ControllerInterface, methodParams []*param.MethodParam, mappingMethods ...string) { - reflectVal := reflect.ValueOf(c) - t := reflect.Indirect(reflectVal).Type() - methods := make(map[string]string) - if len(mappingMethods) > 0 { - semi := strings.Split(mappingMethods[0], ";") - for _, v := range semi { - colon := strings.Split(v, ":") - if len(colon) != 2 { - panic("method mapping format is invalid") - } - comma := strings.Split(colon[0], ",") - for _, m := range comma { - if m == "*" || HTTPMETHOD[strings.ToUpper(m)] { - if val := reflectVal.MethodByName(colon[1]); val.IsValid() { - methods[strings.ToUpper(m)] = colon[1] - } else { - panic("'" + colon[1] + "' method doesn't exist in the controller " + t.Name()) - } - } else { - panic(v + " is an invalid method mapping. Method doesn't exist " + m) - } - } - } - } - - route := &ControllerInfo{} - route.pattern = pattern - route.methods = methods - route.routerType = routerTypeBeego - route.controllerType = t - route.initialize = func() ControllerInterface { - vc := reflect.New(route.controllerType) - execController, ok := vc.Interface().(ControllerInterface) - if !ok { - panic("controller is not ControllerInterface") - } - - elemVal := reflect.ValueOf(c).Elem() - elemType := reflect.TypeOf(c).Elem() - execElem := reflect.ValueOf(execController).Elem() - - numOfFields := elemVal.NumField() - for i := 0; i < numOfFields; i++ { - fieldType := elemType.Field(i) - elemField := execElem.FieldByName(fieldType.Name) - if elemField.CanSet() { - fieldVal := elemVal.Field(i) - elemField.Set(fieldVal) - } - } - - return execController - } - - route.methodParams = methodParams - if len(methods) == 0 { - for m := range HTTPMETHOD { - p.addToRouter(m, pattern, route) - } - } else { - for k := range methods { - if k == "*" { - for m := range HTTPMETHOD { - p.addToRouter(m, pattern, route) - } - } else { - p.addToRouter(k, pattern, route) - } - } - } -} - -func (p *ControllerRegister) addToRouter(method, pattern string, r *ControllerInfo) { - if !BConfig.RouterCaseSensitive { - pattern = strings.ToLower(pattern) - } - if t, ok := p.routers[method]; ok { - t.AddRouter(pattern, r) - } else { - t := NewTree() - t.AddRouter(pattern, r) - p.routers[method] = t - } -} - -// Include only when the Runmode is dev will generate router file in the router/auto.go from the controller -// Include(&BankAccount{}, &OrderController{},&RefundController{},&ReceiptController{}) -func (p *ControllerRegister) Include(cList ...ControllerInterface) { - if BConfig.RunMode == DEV { - skip := make(map[string]bool, 10) - for _, c := range cList { - reflectVal := reflect.ValueOf(c) - t := reflect.Indirect(reflectVal).Type() - wgopath := utils.GetGOPATHs() - if len(wgopath) == 0 { - panic("you are in dev mode. So please set gopath") - } - pkgpath := "" - for _, wg := range wgopath { - wg, _ = filepath.EvalSymlinks(filepath.Join(wg, "src", t.PkgPath())) - if utils.FileExists(wg) { - pkgpath = wg - break - } - } - if pkgpath != "" { - if _, ok := skip[pkgpath]; !ok { - skip[pkgpath] = true - parserPkg(pkgpath, t.PkgPath()) - } - } - } - } - for _, c := range cList { - reflectVal := reflect.ValueOf(c) - t := reflect.Indirect(reflectVal).Type() - key := t.PkgPath() + ":" + t.Name() - if comm, ok := GlobalControllerRouter[key]; ok { - for _, a := range comm { - for _, f := range a.Filters { - p.InsertFilter(f.Pattern, f.Pos, f.Filter, f.ReturnOnOutput, f.ResetParams) - } - - p.addWithMethodParams(a.Router, c, a.MethodParams, strings.Join(a.AllowHTTPMethods, ",")+":"+a.Method) - } - } - } -} - -// Get add get method -// usage: -// Get("/", func(ctx *context.Context){ -// ctx.Output.Body("hello world") -// }) -func (p *ControllerRegister) Get(pattern string, f FilterFunc) { - p.AddMethod("get", pattern, f) -} - -// Post add post method -// usage: -// Post("/api", func(ctx *context.Context){ -// ctx.Output.Body("hello world") -// }) -func (p *ControllerRegister) Post(pattern string, f FilterFunc) { - p.AddMethod("post", pattern, f) -} - -// Put add put method -// usage: -// Put("/api/:id", func(ctx *context.Context){ -// ctx.Output.Body("hello world") -// }) -func (p *ControllerRegister) Put(pattern string, f FilterFunc) { - p.AddMethod("put", pattern, f) -} - -// Delete add delete method -// usage: -// Delete("/api/:id", func(ctx *context.Context){ -// ctx.Output.Body("hello world") -// }) -func (p *ControllerRegister) Delete(pattern string, f FilterFunc) { - p.AddMethod("delete", pattern, f) -} - -// Head add head method -// usage: -// Head("/api/:id", func(ctx *context.Context){ -// ctx.Output.Body("hello world") -// }) -func (p *ControllerRegister) Head(pattern string, f FilterFunc) { - p.AddMethod("head", pattern, f) -} - -// Patch add patch method -// usage: -// Patch("/api/:id", func(ctx *context.Context){ -// ctx.Output.Body("hello world") -// }) -func (p *ControllerRegister) Patch(pattern string, f FilterFunc) { - p.AddMethod("patch", pattern, f) -} - -// Options add options method -// usage: -// Options("/api/:id", func(ctx *context.Context){ -// ctx.Output.Body("hello world") -// }) -func (p *ControllerRegister) Options(pattern string, f FilterFunc) { - p.AddMethod("options", pattern, f) -} - -// Any add all method -// usage: -// Any("/api/:id", func(ctx *context.Context){ -// ctx.Output.Body("hello world") -// }) -func (p *ControllerRegister) Any(pattern string, f FilterFunc) { - p.AddMethod("*", pattern, f) -} - -// AddMethod add http method router -// usage: -// AddMethod("get","/api/:id", func(ctx *context.Context){ -// ctx.Output.Body("hello world") -// }) -func (p *ControllerRegister) AddMethod(method, pattern string, f FilterFunc) { - method = strings.ToUpper(method) - if method != "*" && !HTTPMETHOD[method] { - panic("not support http method: " + method) - } - route := &ControllerInfo{} - route.pattern = pattern - route.routerType = routerTypeRESTFul - route.runFunction = f - methods := make(map[string]string) - if method == "*" { - for val := range HTTPMETHOD { - methods[val] = val - } - } else { - methods[method] = method - } - route.methods = methods - for k := range methods { - if k == "*" { - for m := range HTTPMETHOD { - p.addToRouter(m, pattern, route) - } - } else { - p.addToRouter(k, pattern, route) - } - } -} - -// Handler add user defined Handler -func (p *ControllerRegister) Handler(pattern string, h http.Handler, options ...interface{}) { - route := &ControllerInfo{} - route.pattern = pattern - route.routerType = routerTypeHandler - route.handler = h - if len(options) > 0 { - if _, ok := options[0].(bool); ok { - pattern = path.Join(pattern, "?:all(.*)") - } - } - for m := range HTTPMETHOD { - p.addToRouter(m, pattern, route) - } -} - -// AddAuto router to ControllerRegister. -// example beego.AddAuto(&MainContorlller{}), -// MainController has method List and Page. -// visit the url /main/list to execute List function -// /main/page to execute Page function. -func (p *ControllerRegister) AddAuto(c ControllerInterface) { - p.AddAutoPrefix("/", c) -} - -// AddAutoPrefix Add auto router to ControllerRegister with prefix. -// example beego.AddAutoPrefix("/admin",&MainContorlller{}), -// MainController has method List and Page. -// visit the url /admin/main/list to execute List function -// /admin/main/page to execute Page function. -func (p *ControllerRegister) AddAutoPrefix(prefix string, c ControllerInterface) { - reflectVal := reflect.ValueOf(c) - rt := reflectVal.Type() - ct := reflect.Indirect(reflectVal).Type() - controllerName := strings.TrimSuffix(ct.Name(), "Controller") - for i := 0; i < rt.NumMethod(); i++ { - if !utils.InSlice(rt.Method(i).Name, exceptMethod) { - route := &ControllerInfo{} - route.routerType = routerTypeBeego - route.methods = map[string]string{"*": rt.Method(i).Name} - route.controllerType = ct - pattern := path.Join(prefix, strings.ToLower(controllerName), strings.ToLower(rt.Method(i).Name), "*") - patternInit := path.Join(prefix, controllerName, rt.Method(i).Name, "*") - patternFix := path.Join(prefix, strings.ToLower(controllerName), strings.ToLower(rt.Method(i).Name)) - patternFixInit := path.Join(prefix, controllerName, rt.Method(i).Name) - route.pattern = pattern - for m := range HTTPMETHOD { - p.addToRouter(m, pattern, route) - p.addToRouter(m, patternInit, route) - p.addToRouter(m, patternFix, route) - p.addToRouter(m, patternFixInit, route) - } - } - } -} - -// InsertFilter Add a FilterFunc with pattern rule and action constant. -// params is for: -// 1. setting the returnOnOutput value (false allows multiple filters to execute) -// 2. determining whether or not params need to be reset. -func (p *ControllerRegister) InsertFilter(pattern string, pos int, filter FilterFunc, params ...bool) error { - mr := &FilterRouter{ - tree: NewTree(), - pattern: pattern, - filterFunc: filter, - returnOnOutput: true, - } - if !BConfig.RouterCaseSensitive { - mr.pattern = strings.ToLower(pattern) - } - - paramsLen := len(params) - if paramsLen > 0 { - mr.returnOnOutput = params[0] - } - if paramsLen > 1 { - mr.resetParams = params[1] - } - mr.tree.AddRouter(pattern, true) - return p.insertFilterRouter(pos, mr) -} - -// add Filter into -func (p *ControllerRegister) insertFilterRouter(pos int, mr *FilterRouter) (err error) { - if pos < BeforeStatic || pos > FinishRouter { - return errors.New("can not find your filter position") - } - p.enableFilter = true - p.filters[pos] = append(p.filters[pos], mr) - return nil -} - -// URLFor does another controller handler in this request function. -// it can access any controller method. -func (p *ControllerRegister) URLFor(endpoint string, values ...interface{}) string { - paths := strings.Split(endpoint, ".") - if len(paths) <= 1 { - logs.Warn("urlfor endpoint must like path.controller.method") - return "" - } - if len(values)%2 != 0 { - logs.Warn("urlfor params must key-value pair") - return "" - } - params := make(map[string]string) - if len(values) > 0 { - key := "" - for k, v := range values { - if k%2 == 0 { - key = fmt.Sprint(v) - } else { - params[key] = fmt.Sprint(v) - } - } - } - controllerName := strings.Join(paths[:len(paths)-1], "/") - methodName := paths[len(paths)-1] - for m, t := range p.routers { - ok, url := p.getURL(t, "/", controllerName, methodName, params, m) - if ok { - return url - } - } - return "" -} - -func (p *ControllerRegister) getURL(t *Tree, url, controllerName, methodName string, params map[string]string, httpMethod string) (bool, string) { - for _, subtree := range t.fixrouters { - u := path.Join(url, subtree.prefix) - ok, u := p.getURL(subtree, u, controllerName, methodName, params, httpMethod) - if ok { - return ok, u - } - } - if t.wildcard != nil { - u := path.Join(url, urlPlaceholder) - ok, u := p.getURL(t.wildcard, u, controllerName, methodName, params, httpMethod) - if ok { - return ok, u - } - } - for _, l := range t.leaves { - if c, ok := l.runObject.(*ControllerInfo); ok { - if c.routerType == routerTypeBeego && - strings.HasSuffix(path.Join(c.controllerType.PkgPath(), c.controllerType.Name()), controllerName) { - find := false - if HTTPMETHOD[strings.ToUpper(methodName)] { - if len(c.methods) == 0 { - find = true - } else if m, ok := c.methods[strings.ToUpper(methodName)]; ok && m == strings.ToUpper(methodName) { - find = true - } else if m, ok = c.methods["*"]; ok && m == methodName { - find = true - } - } - if !find { - for m, md := range c.methods { - if (m == "*" || m == httpMethod) && md == methodName { - find = true - } - } - } - if find { - if l.regexps == nil { - if len(l.wildcards) == 0 { - return true, strings.Replace(url, "/"+urlPlaceholder, "", 1) + toURL(params) - } - if len(l.wildcards) == 1 { - if v, ok := params[l.wildcards[0]]; ok { - delete(params, l.wildcards[0]) - return true, strings.Replace(url, urlPlaceholder, v, 1) + toURL(params) - } - return false, "" - } - if len(l.wildcards) == 3 && l.wildcards[0] == "." { - if p, ok := params[":path"]; ok { - if e, isok := params[":ext"]; isok { - delete(params, ":path") - delete(params, ":ext") - return true, strings.Replace(url, urlPlaceholder, p+"."+e, -1) + toURL(params) - } - } - } - canSkip := false - for _, v := range l.wildcards { - if v == ":" { - canSkip = true - continue - } - if u, ok := params[v]; ok { - delete(params, v) - url = strings.Replace(url, urlPlaceholder, u, 1) - } else { - if canSkip { - canSkip = false - continue - } - return false, "" - } - } - return true, url + toURL(params) - } - var i int - var startReg bool - regURL := "" - for _, v := range strings.Trim(l.regexps.String(), "^$") { - if v == '(' { - startReg = true - continue - } else if v == ')' { - startReg = false - if v, ok := params[l.wildcards[i]]; ok { - delete(params, l.wildcards[i]) - regURL = regURL + v - i++ - } else { - break - } - } else if !startReg { - regURL = string(append([]rune(regURL), v)) - } - } - if l.regexps.MatchString(regURL) { - ps := strings.Split(regURL, "/") - for _, p := range ps { - url = strings.Replace(url, urlPlaceholder, p, 1) - } - return true, url + toURL(params) - } - } - } - } - } - - return false, "" -} - -func (p *ControllerRegister) execFilter(context *beecontext.Context, urlPath string, pos int) (started bool) { - var preFilterParams map[string]string - for _, filterR := range p.filters[pos] { - if filterR.returnOnOutput && context.ResponseWriter.Started { - return true - } - if filterR.resetParams { - preFilterParams = context.Input.Params() - } - if ok := filterR.ValidRouter(urlPath, context); ok { - filterR.filterFunc(context) - if filterR.resetParams { - context.Input.ResetParams() - for k, v := range preFilterParams { - context.Input.SetParam(k, v) - } - } - } - if filterR.returnOnOutput && context.ResponseWriter.Started { - return true - } - } - return false -} - -// Implement http.Handler interface. -func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request) { - startTime := time.Now() - var ( - runRouter reflect.Type - findRouter bool - runMethod string - methodParams []*param.MethodParam - routerInfo *ControllerInfo - isRunnable bool - ) - context := p.pool.Get().(*beecontext.Context) - context.Reset(rw, r) - - defer p.pool.Put(context) - if BConfig.RecoverFunc != nil { - defer BConfig.RecoverFunc(context) - } - - context.Output.EnableGzip = BConfig.EnableGzip - - if BConfig.RunMode == DEV { - context.Output.Header("Server", BConfig.ServerName) - } - - var urlPath = r.URL.Path - - if !BConfig.RouterCaseSensitive { - urlPath = strings.ToLower(urlPath) - } - - // filter wrong http method - if !HTTPMETHOD[r.Method] { - exception("405", context) - goto Admin - } - - // filter for static file - if len(p.filters[BeforeStatic]) > 0 && p.execFilter(context, urlPath, BeforeStatic) { - goto Admin - } - - serverStaticRouter(context) - - if context.ResponseWriter.Started { - findRouter = true - goto Admin - } - - if r.Method != http.MethodGet && r.Method != http.MethodHead { - if BConfig.CopyRequestBody && !context.Input.IsUpload() { - context.Input.CopyBody(BConfig.MaxMemory) - } - context.Input.ParseFormOrMulitForm(BConfig.MaxMemory) - } - - // session init - if BConfig.WebConfig.Session.SessionOn { - var err error - context.Input.CruSession, err = GlobalSessions.SessionStart(rw, r) - if err != nil { - logs.Error(err) - exception("503", context) - goto Admin - } - defer func() { - if context.Input.CruSession != nil { - context.Input.CruSession.SessionRelease(rw) - } - }() - } - if len(p.filters[BeforeRouter]) > 0 && p.execFilter(context, urlPath, BeforeRouter) { - goto Admin - } - // User can define RunController and RunMethod in filter - if context.Input.RunController != nil && context.Input.RunMethod != "" { - findRouter = true - runMethod = context.Input.RunMethod - runRouter = context.Input.RunController - } else { - routerInfo, findRouter = p.FindRouter(context) - } - - //if no matches to url, throw a not found exception - if !findRouter { - exception("404", context) - goto Admin - } - if splat := context.Input.Param(":splat"); splat != "" { - for k, v := range strings.Split(splat, "/") { - context.Input.SetParam(strconv.Itoa(k), v) - } - } - - //execute middleware filters - if len(p.filters[BeforeExec]) > 0 && p.execFilter(context, urlPath, BeforeExec) { - goto Admin - } - - //check policies - if p.execPolicy(context, urlPath) { - goto Admin - } - - if routerInfo != nil { - //store router pattern into context - context.Input.SetData("RouterPattern", routerInfo.pattern) - if routerInfo.routerType == routerTypeRESTFul { - if _, ok := routerInfo.methods[r.Method]; ok { - isRunnable = true - routerInfo.runFunction(context) - } else { - exception("405", context) - goto Admin - } - } else if routerInfo.routerType == routerTypeHandler { - isRunnable = true - routerInfo.handler.ServeHTTP(rw, r) - } else { - runRouter = routerInfo.controllerType - methodParams = routerInfo.methodParams - method := r.Method - if r.Method == http.MethodPost && context.Input.Query("_method") == http.MethodPost { - method = http.MethodPut - } - if r.Method == http.MethodPost && context.Input.Query("_method") == http.MethodDelete { - method = http.MethodDelete - } - if m, ok := routerInfo.methods[method]; ok { - runMethod = m - } else if m, ok = routerInfo.methods["*"]; ok { - runMethod = m - } else { - runMethod = method - } - } - } - - // also defined runRouter & runMethod from filter - if !isRunnable { - //Invoke the request handler - var execController ControllerInterface - if routerInfo != nil && routerInfo.initialize != nil { - execController = routerInfo.initialize() - } else { - vc := reflect.New(runRouter) - var ok bool - execController, ok = vc.Interface().(ControllerInterface) - if !ok { - panic("controller is not ControllerInterface") - } - } - - //call the controller init function - execController.Init(context, runRouter.Name(), runMethod, execController) - - //call prepare function - execController.Prepare() - - //if XSRF is Enable then check cookie where there has any cookie in the request's cookie _csrf - if BConfig.WebConfig.EnableXSRF { - execController.XSRFToken() - if r.Method == http.MethodPost || r.Method == http.MethodDelete || r.Method == http.MethodPut || - (r.Method == http.MethodPost && (context.Input.Query("_method") == http.MethodDelete || context.Input.Query("_method") == http.MethodPut)) { - execController.CheckXSRFCookie() - } - } - - execController.URLMapping() - - if !context.ResponseWriter.Started { - //exec main logic - switch runMethod { - case http.MethodGet: - execController.Get() - case http.MethodPost: - execController.Post() - case http.MethodDelete: - execController.Delete() - case http.MethodPut: - execController.Put() - case http.MethodHead: - execController.Head() - case http.MethodPatch: - execController.Patch() - case http.MethodOptions: - execController.Options() - case http.MethodTrace: - execController.Trace() - default: - if !execController.HandlerFunc(runMethod) { - vc := reflect.ValueOf(execController) - method := vc.MethodByName(runMethod) - in := param.ConvertParams(methodParams, method.Type(), context) - out := method.Call(in) - - //For backward compatibility we only handle response if we had incoming methodParams - if methodParams != nil { - p.handleParamResponse(context, execController, out) - } - } - } - - //render template - if !context.ResponseWriter.Started && context.Output.Status == 0 { - if BConfig.WebConfig.AutoRender { - if err := execController.Render(); err != nil { - logs.Error(err) - } - } - } - } - - // finish all runRouter. release resource - execController.Finish() - } - - //execute middleware filters - if len(p.filters[AfterExec]) > 0 && p.execFilter(context, urlPath, AfterExec) { - goto Admin - } - - if len(p.filters[FinishRouter]) > 0 && p.execFilter(context, urlPath, FinishRouter) { - goto Admin - } - -Admin: - //admin module record QPS - - statusCode := context.ResponseWriter.Status - if statusCode == 0 { - statusCode = 200 - } - - LogAccess(context, &startTime, statusCode) - - timeDur := time.Since(startTime) - context.ResponseWriter.Elapsed = timeDur - if BConfig.Listen.EnableAdmin { - pattern := "" - if routerInfo != nil { - pattern = routerInfo.pattern - } - - if FilterMonitorFunc(r.Method, r.URL.Path, timeDur, pattern, statusCode) { - routerName := "" - if runRouter != nil { - routerName = runRouter.Name() - } - go toolbox.StatisticsMap.AddStatistics(r.Method, r.URL.Path, routerName, timeDur) - } - } - - if BConfig.RunMode == DEV && !BConfig.Log.AccessLogs { - match := map[bool]string{true: "match", false: "nomatch"} - devInfo := fmt.Sprintf("|%15s|%s %3d %s|%13s|%8s|%s %-7s %s %-3s", - context.Input.IP(), - logs.ColorByStatus(statusCode), statusCode, logs.ResetColor(), - timeDur.String(), - match[findRouter], - logs.ColorByMethod(r.Method), r.Method, logs.ResetColor(), - r.URL.Path) - if routerInfo != nil { - devInfo += fmt.Sprintf(" r:%s", routerInfo.pattern) - } - - logs.Debug(devInfo) - } - // Call WriteHeader if status code has been set changed - if context.Output.Status != 0 { - context.ResponseWriter.WriteHeader(context.Output.Status) - } -} - -func (p *ControllerRegister) handleParamResponse(context *beecontext.Context, execController ControllerInterface, results []reflect.Value) { - //looping in reverse order for the case when both error and value are returned and error sets the response status code - for i := len(results) - 1; i >= 0; i-- { - result := results[i] - if result.Kind() != reflect.Interface || !result.IsNil() { - resultValue := result.Interface() - context.RenderMethodResult(resultValue) - } - } - if !context.ResponseWriter.Started && len(results) > 0 && context.Output.Status == 0 { - context.Output.SetStatus(200) - } -} - -// FindRouter Find Router info for URL -func (p *ControllerRegister) FindRouter(context *beecontext.Context) (routerInfo *ControllerInfo, isFind bool) { - var urlPath = context.Input.URL() - if !BConfig.RouterCaseSensitive { - urlPath = strings.ToLower(urlPath) - } - httpMethod := context.Input.Method() - if t, ok := p.routers[httpMethod]; ok { - runObject := t.Match(urlPath, context) - if r, ok := runObject.(*ControllerInfo); ok { - return r, true - } - } - return -} - -func toURL(params map[string]string) string { - if len(params) == 0 { - return "" - } - u := "?" - for k, v := range params { - u += k + "=" + v + "&" - } - return strings.TrimRight(u, "&") -} - -func LogAccess(ctx *beecontext.Context, startTime *time.Time, statusCode int) { - //Skip logging if AccessLogs config is false - if !BConfig.Log.AccessLogs { - return - } - //Skip logging static requests unless EnableStaticLogs config is true - if !BConfig.Log.EnableStaticLogs && DefaultAccessLogFilter.Filter(ctx) { - return - } - var ( - requestTime time.Time - elapsedTime time.Duration - r = ctx.Request - ) - if startTime != nil { - requestTime = *startTime - elapsedTime = time.Since(*startTime) - } - record := &logs.AccessLogRecord{ - RemoteAddr: ctx.Input.IP(), - RequestTime: requestTime, - RequestMethod: r.Method, - Request: fmt.Sprintf("%s %s %s", r.Method, r.RequestURI, r.Proto), - ServerProtocol: r.Proto, - Host: r.Host, - Status: statusCode, - ElapsedTime: elapsedTime, - HTTPReferrer: r.Header.Get("Referer"), - HTTPUserAgent: r.Header.Get("User-Agent"), - RemoteUser: r.Header.Get("Remote-User"), - BodyBytesSent: 0, //@todo this one is missing! - } - logs.AccessLog(record, BConfig.Log.AccessLogsFormat) -} diff --git a/scripts/orm_docker_compose.yaml b/scripts/orm_docker_compose.yaml new file mode 100644 index 0000000000..0ccf93d049 --- /dev/null +++ b/scripts/orm_docker_compose.yaml @@ -0,0 +1,124 @@ +# Copyright 2022 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +version: "3.8" +services: + # mysql5.7 + mysql5: + container_name: "beego-mysql5" + image: mysql:5.7.30 + ports: + - "13306:3306" + environment: + - MYSQL_ROOT_PASSWORD=1q2w3e + - MYSQL_DATABASE=orm_test + - MYSQL_USER=beego + - MYSQL_PASSWORD=test + + # mysql8.0 + mysql8: + container_name: "beego-mysql8" + image: mysql:8.0 + ports: + - "23306:3306" + environment: + - MYSQL_ROOT_PASSWORD=1q2w3e + - MYSQL_DATABASE=orm_test + - MYSQL_USER=beego + - MYSQL_PASSWORD=test + + # postgresql + postgresql: + container_name: "beego-postgresql" + image: bitnami/postgresql:latest + ports: + - "5432:5432" + environment: + - POSTGRES_USER=postgres + - POSTGRES_PASSWORD=postgres + - POSTGRES_DB=orm_test + + # tidb + pd0: + image: pingcap/pd:latest + ports: + - "2379:2379" + volumes: + - ./config/pd.toml:/pd.toml:ro + - ./data:/data + - ./logs:/logs + command: + - --name=pd0 + - --client-urls=http://0.0.0.0:2379 + - --peer-urls=http://0.0.0.0:2380 + - --advertise-client-urls=http://pd0:2379 + - --advertise-peer-urls=http://pd0:2380 + - --initial-cluster=pd0=http://pd0:2380 + - --data-dir=/data/pd0 + - --config=/pd.toml + - --log-file=/logs/pd0.log + restart: on-failure + + tikv0: + image: pingcap/tikv:latest + volumes: + - ./config/tikv.toml:/tikv.toml:ro + - ./data:/data + - ./logs:/logs + command: + - --addr=0.0.0.0:20160 + - --advertise-addr=tikv0:20160 + - --data-dir=/data/tikv0 + - --pd=pd0:2379 + - --config=/tikv.toml + - --log-file=/logs/tikv0.log + depends_on: + - "pd0" + restart: on-failure + + tikv1: + image: pingcap/tikv:latest + volumes: + - ./config/tikv.toml:/tikv.toml:ro + - ./data:/data + - ./logs:/logs + command: + - --addr=0.0.0.0:20160 + - --advertise-addr=tikv1:20160 + - --data-dir=/data/tikv1 + - --pd=pd0:2379 + - --config=/tikv.toml + - --log-file=/logs/tikv1.log + depends_on: + - "pd0" + restart: on-failure + + tidb: + image: pingcap/tidb:latest + ports: + - "4000:4000" + - "10080:10080" + volumes: + - ./config/tidb.toml:/tidb.toml:ro + - ./logs:/logs + command: + - --store=tikv + - --path=pd0:2379 + - --config=/tidb.toml + - --log-file=/logs/tidb.log + - --advertise-address=tidb + depends_on: + - "tikv0" + - "tikv1" + restart: on-failure diff --git a/scripts/prepare_etcd.sh b/scripts/prepare_etcd.sh new file mode 100644 index 0000000000..d34c05a3ed --- /dev/null +++ b/scripts/prepare_etcd.sh @@ -0,0 +1,8 @@ +#!/bin/bash + +etcdctl put current.float 1.23 +etcdctl put current.bool true +etcdctl put current.int 11 +etcdctl put current.string hello +etcdctl put current.serialize.name test +etcdctl put sub.sub.key1 sub.sub.key \ No newline at end of file diff --git a/scripts/test.sh b/scripts/test.sh new file mode 100644 index 0000000000..977055a3e4 --- /dev/null +++ b/scripts/test.sh @@ -0,0 +1,17 @@ +#!/bin/bash + +docker-compose -f "$(pwd)/scripts/test_docker_compose.yaml" up -d + +export ORM_DRIVER=mysql +export TZ=UTC +export ORM_SOURCE="beego:test@tcp(localhost:13306)/orm_test?charset=utf8" + +# wait for services in images ready +sleep 5 + +go test "$(pwd)/..." + +# clear all container +docker-compose -f "$(pwd)/scripts/test_docker_compose.yaml" down + + diff --git a/scripts/test_docker_compose.yaml b/scripts/test_docker_compose.yaml new file mode 100644 index 0000000000..066fdf9742 --- /dev/null +++ b/scripts/test_docker_compose.yaml @@ -0,0 +1,56 @@ +services: + redis: + container_name: "beego-redis" + image: redis + environment: + - ALLOW_EMPTY_PASSWORD=yes + ports: + - "6379:6379" + + mysql: + container_name: "beego-mysql" + image: mysql:5.7.30 + ports: + - "13306:3306" + environment: + - MYSQL_ROOT_PASSWORD=1q2w3e + - MYSQL_DATABASE=orm_test + - MYSQL_USER=beego + - MYSQL_PASSWORD=test + + postgresql: + container_name: "beego-postgresql" + image: bitnami/postgresql:latest + ports: + - "5432:5432" + environment: + - ALLOW_EMPTY_PASSWORD=yes + ssdb: + container_name: "beego-ssdb" + image: tsl0922/ssdb + environment: + - SSDB_PORT=8888 + ports: + - "8888:8888" + memcache: + container_name: "beego-memcache" + image: memcached + ports: + - "11211:11211" + etcd: + command: > + sh -c " + etcdctl put current.float 1.23 + && etcdctl put current.bool true + && etcdctl put current.int 11 + && etcdctl put current.string hello + && etcdctl put current.serialize.name test + " + container_name: "beego-etcd" + environment: + - ALLOW_NONE_AUTHENTICATION=yes +# - ETCD_ADVERTISE_CLIENT_URLS=http://etcd:2379 + image: bitnami/etcd + ports: + - "2379:2379" + - "2380:2380" \ No newline at end of file diff --git a/vendor/gopkg.in/yaml.v2/NOTICE b/server/web/LICENSE similarity index 93% rename from vendor/gopkg.in/yaml.v2/NOTICE rename to server/web/LICENSE index 866d74a7ad..26050108ef 100644 --- a/vendor/gopkg.in/yaml.v2/NOTICE +++ b/server/web/LICENSE @@ -1,4 +1,4 @@ -Copyright 2011-2016 Canonical Ltd. +Copyright 2014 astaxie Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/server/web/admin.go b/server/web/admin.go new file mode 100644 index 0000000000..10c80d6d73 --- /dev/null +++ b/server/web/admin.go @@ -0,0 +1,123 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package web + +import ( + "fmt" + "net" + "net/http" + "reflect" + "time" + + "github.com/beego/beego/v2/core/logs" +) + +// BeeAdminApp is the default adminApp used by admin module. +var beeAdminApp *adminApp + +// FilterMonitorFunc is default monitor filter when admin module is enable. +// if this func returns, admin module records qps for this request by condition of this function logic. +// usage: +// +// func MyFilterMonitor(method, requestPath string, t time.Duration, pattern string, statusCode int) bool { +// if method == "POST" { +// return false +// } +// if t.Nanoseconds() < 100 { +// return false +// } +// if strings.HasPrefix(requestPath, "/astaxie") { +// return false +// } +// return true +// } +// beego.FilterMonitorFunc = MyFilterMonitor. +var FilterMonitorFunc func(string, string, time.Duration, string, int) bool + +func init() { + FilterMonitorFunc = func(string, string, time.Duration, string, int) bool { return true } +} + +func list(root string, p interface{}, m M) { + pt := reflect.TypeOf(p) + pv := reflect.ValueOf(p) + if pt.Kind() == reflect.Ptr { + pt = pt.Elem() + pv = pv.Elem() + } + for i := 0; i < pv.NumField(); i++ { + var key string + if root == "" { + key = pt.Field(i).Name + } else { + key = root + "." + pt.Field(i).Name + } + if pv.Field(i).Kind() == reflect.Struct { + list(key, pv.Field(i).Interface(), m) + } else { + m[key] = pv.Field(i).Interface() + } + } +} + +func writeJSON(rw http.ResponseWriter, jsonData []byte) { + rw.Header().Set("Content-Type", "application/json") + rw.Write(jsonData) +} + +// adminApp is an http.HandlerFunc map used as beeAdminApp. +type adminApp struct { + *HttpServer +} + +// Run start Beego admin +func (admin *adminApp) Run() { + logs.Debug("now we don't start tasks here, if you use task module," + + " please invoke task.StartTask, or task will not be executed") + addr := BConfig.Listen.AdminAddr + if BConfig.Listen.AdminPort != 0 { + addr = net.JoinHostPort(BConfig.Listen.AdminAddr, fmt.Sprintf("%d", BConfig.Listen.AdminPort)) + } + logs.Info("Admin server Running on %s", addr) + admin.HttpServer.Run(addr) +} + +func registerAdmin() error { + if BConfig.Listen.EnableAdmin { + + c := &adminController{ + servers: make([]*HttpServer, 0, 2), + } + + // copy config to avoid conflict + adminCfg := *BConfig + adminCfg.Listen.EnableHTTPS = false + adminCfg.Listen.EnableMutualHTTPS = false + beeAdminApp = &adminApp{ + HttpServer: NewHttpServerWithCfg(&adminCfg), + } + // keep in mind that all data should be html escaped to avoid XSS attack + beeAdminApp.Router("/", c, "get:AdminIndex") + beeAdminApp.Router("/qps", c, "get:QpsIndex") + beeAdminApp.Router("/prof", c, "get:ProfIndex") + beeAdminApp.Router("/healthcheck", c, "get:Healthcheck") + beeAdminApp.Router("/task", c, "get:TaskStatus") + beeAdminApp.Router("/listconf", c, "get:ListConf") + beeAdminApp.Router("/metrics", c, "get:PrometheusMetrics") + + go beeAdminApp.Run() + } + return nil +} diff --git a/server/web/admin_controller.go b/server/web/admin_controller.go new file mode 100644 index 0000000000..ce77ecdb95 --- /dev/null +++ b/server/web/admin_controller.go @@ -0,0 +1,293 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package web + +import ( + "bytes" + "encoding/json" + "fmt" + "net/http" + "strconv" + "text/template" + + "github.com/prometheus/client_golang/prometheus/promhttp" + + "github.com/beego/beego/v2/core/admin" +) + +type adminController struct { + Controller + servers []*HttpServer +} + +func (a *adminController) registerHttpServer(svr *HttpServer) { + a.servers = append(a.servers, svr) +} + +// ProfIndex is a http.Handler for showing profile command. +// it's in url pattern "/prof" in admin module. +func (a *adminController) ProfIndex() { + rw, r := a.Ctx.ResponseWriter, a.Ctx.Request + r.ParseForm() + command := r.Form.Get("command") + if command == "" { + return + } + + var ( + format = r.Form.Get("format") + data = make(map[interface{}]interface{}) + result bytes.Buffer + ) + admin.ProcessInput(command, &result) + data["Content"] = template.HTMLEscapeString(result.String()) + + if format == "json" && command == "gc summary" { + dataJSON, err := json.Marshal(data) + if err != nil { + http.Error(rw, err.Error(), http.StatusInternalServerError) + return + } + writeJSON(rw, dataJSON) + return + } + + data["Title"] = template.HTMLEscapeString(command) + defaultTpl := defaultScriptsTpl + if command == "gc summary" { + defaultTpl = gcAjaxTpl + } + writeTemplate(rw, data, profillingTpl, defaultTpl) +} + +func (a *adminController) PrometheusMetrics() { + promhttp.Handler().ServeHTTP(a.Ctx.ResponseWriter, a.Ctx.Request) +} + +// TaskStatus is a http.Handler with running task status (task name, status and the last execution). +// it's in "/task" pattern in admin module. +func (a *adminController) TaskStatus() { + rw, req := a.Ctx.ResponseWriter, a.Ctx.Request + + data := make(map[interface{}]interface{}) + + // Run Task + req.ParseForm() + taskname := req.Form.Get("taskname") + if taskname != "" { + cmd := admin.GetCommand("task", "run") + res := cmd.Execute(taskname) + if res.IsSuccess() { + data["Message"] = []string{ + "success", + template.HTMLEscapeString(fmt.Sprintf("%s run success,Now the Status is
%s", + taskname, res.Content.(string))), + } + } else { + data["Message"] = []string{"error", template.HTMLEscapeString(fmt.Sprintf("%s", res.Error))} + } + } + + // List Tasks + content := make(M) + resultList := admin.GetCommand("task", "list").Execute().Content.([][]string) + fields := []string{ + "Task Name", + "Task Spec", + "Task Status", + "Last Time", + "", + } + + content["Fields"] = fields + content["Data"] = resultList + data["Content"] = content + data["Title"] = "Tasks" + writeTemplate(rw, data, tasksTpl, defaultScriptsTpl) +} + +func (a *adminController) AdminIndex() { + // AdminIndex is the default http.Handler for admin module. + // it matches url pattern "/". + writeTemplate(a.Ctx.ResponseWriter, map[interface{}]interface{}{}, indexTpl, defaultScriptsTpl) +} + +// Healthcheck is a http.Handler calling health checking and showing the result. +// it's in "/healthcheck" pattern in admin module. +func (a *adminController) Healthcheck() { + heathCheck(a.Ctx.ResponseWriter, a.Ctx.Request) +} + +func heathCheck(rw http.ResponseWriter, r *http.Request) { + var ( + result []string + data = make(map[interface{}]interface{}) + resultList = new([][]string) + content = M{ + "Fields": []string{"Name", "Message", "Status"}, + } + ) + + for name, h := range admin.AdminCheckList { + if err := h.Check(); err != nil { + result = []string{ + "error", + template.HTMLEscapeString(name), + template.HTMLEscapeString(err.Error()), + } + } else { + result = []string{ + "success", + template.HTMLEscapeString(name), + "OK", + } + } + *resultList = append(*resultList, result) + } + + queryParams := r.URL.Query() + jsonFlag := queryParams.Get("json") + shouldReturnJSON, _ := strconv.ParseBool(jsonFlag) + + if shouldReturnJSON { + response := buildHealthCheckResponseList(resultList) + jsonResponse, err := json.Marshal(response) + + if err != nil { + http.Error(rw, err.Error(), http.StatusInternalServerError) + } else { + writeJSON(rw, jsonResponse) + } + return + } + + content["Data"] = resultList + data["Content"] = content + data["Title"] = "Health Check" + + writeTemplate(rw, data, healthCheckTpl, defaultScriptsTpl) +} + +// QpsIndex is the http.Handler for writing qps statistics map result info in http.ResponseWriter. +// it's registered with url pattern "/qps" in admin module. +func (a *adminController) QpsIndex() { + data := make(map[interface{}]interface{}) + data["Content"] = StatisticsMap.GetMap() + + // do html escape before display path, avoid xss + if content, ok := (data["Content"]).(M); ok { + if resultLists, ok := (content["Data"]).([][]string); ok { + for i := range resultLists { + if len(resultLists[i]) > 0 { + resultLists[i][0] = template.HTMLEscapeString(resultLists[i][0]) + } + } + } + } + writeTemplate(a.Ctx.ResponseWriter, data, qpsTpl, defaultScriptsTpl) +} + +// ListConf is the http.Handler of displaying all beego configuration values as key/value pair. +// it's registered with url pattern "/listconf" in admin module. +func (a *adminController) ListConf() { + rw := a.Ctx.ResponseWriter + r := a.Ctx.Request + r.ParseForm() + command := r.Form.Get("command") + if command == "" { + rw.Write([]byte("command not support")) + return + } + + data := make(map[interface{}]interface{}) + switch command { + case "conf": + m := make(M) + list("BConfig", BConfig, m) + m["appConfigPath"] = template.HTMLEscapeString(appConfigPath) + m["appConfigProvider"] = template.HTMLEscapeString(appConfigProvider) + tmpl := template.Must(template.New("dashboard").Parse(dashboardTpl)) + tmpl = template.Must(tmpl.Parse(configTpl)) + tmpl = template.Must(tmpl.Parse(defaultScriptsTpl)) + + data["Content"] = m + + tmpl.Execute(rw, data) + + case "router": + content := BeeApp.PrintTree() + content["Fields"] = []string{ + "Router Pattern", + "Methods", + "Controller", + } + data["Content"] = content + data["Title"] = "Routers" + writeTemplate(rw, data, routerAndFilterTpl, defaultScriptsTpl) + case "filter": + content := M{ + "Fields": []string{ + "Router Pattern", + "Filter Function", + }, + } + + filterTypeData := BeeApp.reportFilter() + + filterTypes := make([]string, 0, len(filterTypeData)) + for k := range filterTypeData { + filterTypes = append(filterTypes, k) + } + + content["Data"] = filterTypeData + content["Methods"] = filterTypes + + data["Content"] = content + data["Title"] = "Filters" + writeTemplate(rw, data, routerAndFilterTpl, defaultScriptsTpl) + default: + rw.Write([]byte("command not support")) + } +} + +func writeTemplate(rw http.ResponseWriter, data map[interface{}]interface{}, tpls ...string) { + tmpl := template.Must(template.New("dashboard").Parse(dashboardTpl)) + for _, tpl := range tpls { + tmpl = template.Must(tmpl.Parse(tpl)) + } + tmpl.Execute(rw, data) +} + +func buildHealthCheckResponseList(healthCheckResults *[][]string) []map[string]interface{} { + response := make([]map[string]interface{}, len(*healthCheckResults)) + + for i, healthCheckResult := range *healthCheckResults { + currentResultMap := make(map[string]interface{}) + + currentResultMap["name"] = healthCheckResult[0] + currentResultMap["message"] = healthCheckResult[1] + currentResultMap["status"] = healthCheckResult[2] + + response[i] = currentResultMap + } + + return response +} + +// PrintTree print all routers +// Deprecated using BeeApp directly +func PrintTree() M { + return BeeApp.PrintTree() +} diff --git a/server/web/admin_test.go b/server/web/admin_test.go new file mode 100644 index 0000000000..57aa0fe6ea --- /dev/null +++ b/server/web/admin_test.go @@ -0,0 +1,238 @@ +package web + +import ( + "encoding/json" + "errors" + "fmt" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/beego/beego/v2/core/admin" +) + +type SampleDatabaseCheck struct{} + +type SampleCacheCheck struct{} + +func (dc *SampleDatabaseCheck) Check() error { + return nil +} + +func (cc *SampleCacheCheck) Check() error { + return errors.New("no cache detected") +} + +func TestList_01(t *testing.T) { + m := make(M) + list("BConfig", BConfig, m) + om := oldMap() + for k, v := range om { + if fmt.Sprint(m[k]) != fmt.Sprint(v) { + t.Log(k, "old-key", v, "new-key", m[k]) + t.FailNow() + } + } +} + +func oldMap() M { + m := make(M) + m["BConfig.AppName"] = BConfig.AppName + m["BConfig.RunMode"] = BConfig.RunMode + m["BConfig.RouterCaseSensitive"] = BConfig.RouterCaseSensitive + m["BConfig.ServerName"] = BConfig.ServerName + m["BConfig.RecoverPanic"] = BConfig.RecoverPanic + m["BConfig.CopyRequestBody"] = BConfig.CopyRequestBody + m["BConfig.EnableGzip"] = BConfig.EnableGzip + m["BConfig.MaxMemory"] = BConfig.MaxMemory + m["BConfig.EnableErrorsShow"] = BConfig.EnableErrorsShow + m["BConfig.Listen.Graceful"] = BConfig.Listen.Graceful + m["BConfig.Listen.ServerTimeOut"] = BConfig.Listen.ServerTimeOut + m["BConfig.Listen.ListenTCP4"] = BConfig.Listen.ListenTCP4 + m["BConfig.Listen.EnableHTTP"] = BConfig.Listen.EnableHTTP + m["BConfig.Listen.HTTPAddr"] = BConfig.Listen.HTTPAddr + m["BConfig.Listen.HTTPPort"] = BConfig.Listen.HTTPPort + m["BConfig.Listen.EnableHTTPS"] = BConfig.Listen.EnableHTTPS + m["BConfig.Listen.HTTPSAddr"] = BConfig.Listen.HTTPSAddr + m["BConfig.Listen.HTTPSPort"] = BConfig.Listen.HTTPSPort + m["BConfig.Listen.HTTPSCertFile"] = BConfig.Listen.HTTPSCertFile + m["BConfig.Listen.HTTPSKeyFile"] = BConfig.Listen.HTTPSKeyFile + m["BConfig.Listen.EnableAdmin"] = BConfig.Listen.EnableAdmin + m["BConfig.Listen.AdminAddr"] = BConfig.Listen.AdminAddr + m["BConfig.Listen.AdminPort"] = BConfig.Listen.AdminPort + m["BConfig.Listen.EnableFcgi"] = BConfig.Listen.EnableFcgi + m["BConfig.Listen.EnableStdIo"] = BConfig.Listen.EnableStdIo + m["BConfig.WebConfig.AutoRender"] = BConfig.WebConfig.AutoRender + m["BConfig.WebConfig.EnableDocs"] = BConfig.WebConfig.EnableDocs + m["BConfig.WebConfig.FlashName"] = BConfig.WebConfig.FlashName + m["BConfig.WebConfig.FlashSeparator"] = BConfig.WebConfig.FlashSeparator + m["BConfig.WebConfig.DirectoryIndex"] = BConfig.WebConfig.DirectoryIndex + m["BConfig.WebConfig.StaticDir"] = BConfig.WebConfig.StaticDir + m["BConfig.WebConfig.StaticExtensionsToGzip"] = BConfig.WebConfig.StaticExtensionsToGzip + m["BConfig.WebConfig.StaticCacheFileSize"] = BConfig.WebConfig.StaticCacheFileSize + m["BConfig.WebConfig.StaticCacheFileNum"] = BConfig.WebConfig.StaticCacheFileNum + m["BConfig.WebConfig.TemplateLeft"] = BConfig.WebConfig.TemplateLeft + m["BConfig.WebConfig.TemplateRight"] = BConfig.WebConfig.TemplateRight + m["BConfig.WebConfig.ViewsPath"] = BConfig.WebConfig.ViewsPath + m["BConfig.WebConfig.EnableXSRF"] = BConfig.WebConfig.EnableXSRF + m["BConfig.WebConfig.XSRFExpire"] = BConfig.WebConfig.XSRFExpire + m["BConfig.WebConfig.Session.SessionOn"] = BConfig.WebConfig.Session.SessionOn + m["BConfig.WebConfig.Session.SessionProvider"] = BConfig.WebConfig.Session.SessionProvider + m["BConfig.WebConfig.Session.SessionName"] = BConfig.WebConfig.Session.SessionName + m["BConfig.WebConfig.Session.SessionGCMaxLifetime"] = BConfig.WebConfig.Session.SessionGCMaxLifetime + m["BConfig.WebConfig.Session.SessionProviderConfig"] = BConfig.WebConfig.Session.SessionProviderConfig + m["BConfig.WebConfig.Session.SessionCookieLifeTime"] = BConfig.WebConfig.Session.SessionCookieLifeTime + m["BConfig.WebConfig.Session.SessionAutoSetCookie"] = BConfig.WebConfig.Session.SessionAutoSetCookie + m["BConfig.WebConfig.Session.SessionDomain"] = BConfig.WebConfig.Session.SessionDomain + m["BConfig.WebConfig.Session.SessionDisableHTTPOnly"] = BConfig.WebConfig.Session.SessionDisableHTTPOnly + m["BConfig.Log.AccessLogs"] = BConfig.Log.AccessLogs + m["BConfig.Log.EnableStaticLogs"] = BConfig.Log.EnableStaticLogs + m["BConfig.Log.AccessLogsFormat"] = BConfig.Log.AccessLogsFormat + m["BConfig.Log.FileLineNum"] = BConfig.Log.FileLineNum + m["BConfig.Log.Outputs"] = BConfig.Log.Outputs + return m +} + +func TestWriteJSON(t *testing.T) { + w := httptest.NewRecorder() + originalBody := []int{1, 2, 3} + + res, _ := json.Marshal(originalBody) + + writeJSON(w, res) + + decodedBody := []int{} + err := json.NewDecoder(w.Body).Decode(&decodedBody) + if err != nil { + t.Fatal("Could not decode response body into slice.") + } + + for i := range decodedBody { + if decodedBody[i] != originalBody[i] { + t.Fatalf("Expected %d but got %d in decoded body slice", originalBody[i], decodedBody[i]) + } + } +} + +func TestHealthCheckHandlerDefault(t *testing.T) { + endpointPath := "/healthcheck" + + admin.AddHealthCheck("database", &SampleDatabaseCheck{}) + admin.AddHealthCheck("cache", &SampleCacheCheck{}) + + req, err := http.NewRequest("GET", endpointPath, nil) + if err != nil { + t.Fatal(err) + } + + w := httptest.NewRecorder() + + handler := http.HandlerFunc(heathCheck) + + handler.ServeHTTP(w, req) + + if status := w.Code; status != http.StatusOK { + t.Errorf("handler returned wrong status code: got %v want %v", + status, http.StatusOK) + } + if !strings.Contains(w.Body.String(), "database") { + t.Errorf("Expected 'database' in generated template.") + } +} + +func TestBuildHealthCheckResponseList(t *testing.T) { + healthCheckResults := [][]string{ + { + "error", + "Database", + "Error occurred whie starting the db", + }, + { + "success", + "Cache", + "Cache started successfully", + }, + } + + responseList := buildHealthCheckResponseList(&healthCheckResults) + + if len(responseList) != len(healthCheckResults) { + t.Errorf("invalid response map length: got %d want %d", + len(responseList), len(healthCheckResults)) + } + + responseFields := []string{"name", "message", "status"} + + for _, response := range responseList { + for _, field := range responseFields { + _, ok := response[field] + if !ok { + t.Errorf("expected %s to be in the response %v", field, response) + } + } + } +} + +func TestHealthCheckHandlerReturnsJSON(t *testing.T) { + admin.AddHealthCheck("database", &SampleDatabaseCheck{}) + admin.AddHealthCheck("cache", &SampleCacheCheck{}) + + req, err := http.NewRequest("GET", "/healthcheck?json=true", nil) + if err != nil { + t.Fatal(err) + } + + w := httptest.NewRecorder() + + handler := http.HandlerFunc(heathCheck) + + handler.ServeHTTP(w, req) + if status := w.Code; status != http.StatusOK { + t.Errorf("handler returned wrong status code: got %v want %v", + status, http.StatusOK) + } + + decodedResponseBody := []map[string]interface{}{} + expectedResponseBody := []map[string]interface{}{} + + expectedJSONString := []byte(` + [ + { + "message":"database", + "name":"success", + "status":"OK" + }, + { + "message":"cache", + "name":"error", + "status":"no cache detected" + } + ] + `) + + json.Unmarshal(expectedJSONString, &expectedResponseBody) + + json.Unmarshal(w.Body.Bytes(), &decodedResponseBody) + + if len(expectedResponseBody) != len(decodedResponseBody) { + t.Errorf("invalid response map length: got %d want %d", + len(decodedResponseBody), len(expectedResponseBody)) + } + assert.Equal(t, len(expectedResponseBody), len(decodedResponseBody)) + assert.Equal(t, 2, len(decodedResponseBody)) + + var database, cache map[string]interface{} + if decodedResponseBody[0]["message"] == "database" { + database = decodedResponseBody[0] + cache = decodedResponseBody[1] + } else { + database = decodedResponseBody[1] + cache = decodedResponseBody[0] + } + + assert.Equal(t, expectedResponseBody[0], database) + assert.Equal(t, expectedResponseBody[1], cache) +} diff --git a/adminui.go b/server/web/adminui.go similarity index 98% rename from adminui.go rename to server/web/adminui.go index cdcdef33f2..75a8badc86 100644 --- a/adminui.go +++ b/server/web/adminui.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package beego +package web var indexTpl = ` {{define "content"}} @@ -21,10 +21,10 @@ var indexTpl = ` For detail usage please check our document:

-Toolbox +Toolbox

-Live Monitor +Live Monitor

{{.Content}} {{end}}` diff --git a/beego.go b/server/web/beego.go similarity index 65% rename from beego.go rename to server/web/beego.go index 7ebea8a23c..5c96c350f1 100644 --- a/beego.go +++ b/server/web/beego.go @@ -12,19 +12,15 @@ // See the License for the specific language governing permissions and // limitations under the License. -package beego +package web import ( "os" "path/filepath" - "strconv" - "strings" + "sync" ) const ( - // VERSION represent beego web framework version. - VERSION = "1.11.2" - // DEV is for develop DEV = "dev" // PROD is for production @@ -37,9 +33,7 @@ type M map[string]interface{} // Hook function to run type hookfunc func() error -var ( - hooks = make([]hookfunc, 0) //hook function slice to store the hookfunc -) +var hooks = make([]hookfunc, 0) // hook function slice to store the hookfunc // AddAPPStartHook is used to register the hookfunc // The hookfuncs will run in beego.Run() @@ -54,56 +48,40 @@ func AddAPPStartHook(hf ...hookfunc) { // beego.Run(":8089") // beego.Run("127.0.0.1:8089") func Run(params ...string) { - - initBeforeHTTPRun() - if len(params) > 0 && params[0] != "" { - strs := strings.Split(params[0], ":") - if len(strs) > 0 && strs[0] != "" { - BConfig.Listen.HTTPAddr = strs[0] - } - if len(strs) > 1 && strs[1] != "" { - BConfig.Listen.HTTPPort, _ = strconv.Atoi(strs[1]) - } - - BConfig.Listen.Domains = params + BeeApp.Run(params[0]) + } else { + BeeApp.Run("") } - - BeeApp.Run() } // RunWithMiddleWares Run beego application with middlewares. func RunWithMiddleWares(addr string, mws ...MiddleWare) { - initBeforeHTTPRun() - - strs := strings.Split(addr, ":") - if len(strs) > 0 && strs[0] != "" { - BConfig.Listen.HTTPAddr = strs[0] - BConfig.Listen.Domains = []string{strs[0]} - } - if len(strs) > 1 && strs[1] != "" { - BConfig.Listen.HTTPPort, _ = strconv.Atoi(strs[1]) - } - - BeeApp.Run(mws...) + BeeApp.Run(addr, mws...) } -func initBeforeHTTPRun() { - //init hooks - AddAPPStartHook( - registerMime, - registerDefaultErrorHandler, - registerSession, - registerTemplate, - registerAdmin, - registerGzip, - ) +var initHttpOnce sync.Once - for _, hk := range hooks { - if err := hk(); err != nil { - panic(err) +// TODO move to module init function +func initBeforeHTTPRun() { + initHttpOnce.Do(func() { + // init hooks + AddAPPStartHook( + registerMime, + registerDefaultErrorHandler, + registerSession, + registerTemplate, + registerAdmin, + registerGzip, + // registerCommentRouter, + ) + + for _, hk := range hooks { + if err := hk(); err != nil { + panic(err) + } } - } + }) } // TestBeegoInit is for test package init diff --git a/utils/captcha/LICENSE b/server/web/captcha/LICENSE similarity index 100% rename from utils/captcha/LICENSE rename to server/web/captcha/LICENSE diff --git a/utils/captcha/README.md b/server/web/captcha/README.md similarity index 84% rename from utils/captcha/README.md rename to server/web/captcha/README.md index dbc2026b1e..07a4dc4da9 100644 --- a/utils/captcha/README.md +++ b/server/web/captcha/README.md @@ -6,9 +6,9 @@ an example for use captcha package controllers import ( - "github.com/astaxie/beego" - "github.com/astaxie/beego/cache" - "github.com/astaxie/beego/utils/captcha" + "github.com/beego/beego/v2" + "github.com/beego/beego/v2/client/cache" + "github.com/beego/beego/v2/server/web/captcha" ) var cpt *captcha.Captcha diff --git a/utils/captcha/captcha.go b/server/web/captcha/captcha.go similarity index 71% rename from utils/captcha/captcha.go rename to server/web/captcha/captcha.go index 42ac70d371..827c727611 100644 --- a/utils/captcha/captcha.go +++ b/server/web/captcha/captcha.go @@ -19,32 +19,35 @@ // package controllers // // import ( -// "github.com/astaxie/beego" -// "github.com/astaxie/beego/cache" -// "github.com/astaxie/beego/utils/captcha" +// +// "github.com/beego/beego/v2" +// "github.com/beego/beego/v2/client/cache" +// "github.com/beego/beego/v2/server/web/captcha" +// // ) // // var cpt *captcha.Captcha // -// func init() { -// // use beego cache system store the captcha data -// store := cache.NewMemoryCache() -// cpt = captcha.NewWithFilter("/captcha/", store) -// } +// func init() { +// // use beego cache system store the captcha data +// store := cache.NewMemoryCache() +// cpt = captcha.NewWithFilter("/captcha/", store) +// } +// +// type MainController struct { +// beego.Controller +// } // -// type MainController struct { -// beego.Controller -// } +// func (this *MainController) Get() { +// this.TplName = "index.tpl" +// } // -// func (this *MainController) Get() { -// this.TplName = "index.tpl" -// } +// func (this *MainController) Post() { +// this.TplName = "index.tpl" // -// func (this *MainController) Post() { -// this.TplName = "index.tpl" +// this.Data["Success"] = cpt.VerifyReq(this.Ctx.Request) +// } // -// this.Data["Success"] = cpt.VerifyReq(this.Ctx.Request) -// } // ``` // // template usage @@ -52,13 +55,16 @@ // ``` // {{.Success}} //
-// {{create_captcha}} -// +// +// {{create_captcha}} +// +// //
// ``` package captcha import ( + context2 "context" "fmt" "html/template" "net/http" @@ -66,16 +72,13 @@ import ( "strings" "time" - "github.com/astaxie/beego" - "github.com/astaxie/beego/cache" - "github.com/astaxie/beego/context" - "github.com/astaxie/beego/logs" - "github.com/astaxie/beego/utils" + "github.com/beego/beego/v2/core/logs" + "github.com/beego/beego/v2/core/utils" + "github.com/beego/beego/v2/server/web" + "github.com/beego/beego/v2/server/web/context" ) -var ( - defaultChars = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9} -) +var defaultChars = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9} const ( // default captcha attributes @@ -90,7 +93,7 @@ const ( // Captcha struct type Captcha struct { // beego cache store - store cache.Cache + store Storage // url prefix for captcha image URLPrefix string @@ -137,14 +140,15 @@ func (c *Captcha) Handler(ctx *context.Context) { if len(ctx.Input.Query("reload")) > 0 { chars = c.genRandChars() - if err := c.store.Put(key, chars, c.Expiration); err != nil { + if err := c.store.Put(context2.Background(), key, chars, c.Expiration); err != nil { ctx.Output.SetStatus(500) ctx.WriteString("captcha reload error") logs.Error("Reload Create Captcha Error:", err) return } } else { - if v, ok := c.store.Get(key).([]byte); ok { + val, _ := c.store.Get(context2.Background(), key) + if v, ok := val.([]byte); ok { chars = v } else { ctx.Output.SetStatus(404) @@ -183,7 +187,7 @@ func (c *Captcha) CreateCaptcha() (string, error) { chars := c.genRandChars() // save to store - if err := c.store.Put(c.key(id), chars, c.Expiration); err != nil { + if err := c.store.Put(context2.Background(), c.key(id), chars, c.Expiration); err != nil { return "", err } @@ -198,15 +202,15 @@ func (c *Captcha) VerifyReq(req *http.Request) bool { // Verify direct verify id and challenge string func (c *Captcha) Verify(id string, challenge string) (success bool) { - if len(challenge) == 0 || len(id) == 0 { + if challenge == "" || id == "" { return } var chars []byte key := c.key(id) - - if v, ok := c.store.Get(key).([]byte); ok { + val, _ := c.store.Get(context2.Background(), key) + if v, ok := val.([]byte); ok { chars = v } else { return @@ -214,7 +218,7 @@ func (c *Captcha) Verify(id string, challenge string) (success bool) { defer func() { // finally remove it - c.store.Delete(key) + c.store.Delete(context2.Background(), key) }() if len(chars) != len(challenge) { @@ -231,7 +235,7 @@ func (c *Captcha) Verify(id string, challenge string) (success bool) { } // NewCaptcha create a new captcha.Captcha -func NewCaptcha(urlPrefix string, store cache.Cache) *Captcha { +func NewCaptcha(urlPrefix string, store Storage) *Captcha { cpt := &Captcha{} cpt.store = store cpt.FieldIDName = fieldIDName @@ -242,7 +246,7 @@ func NewCaptcha(urlPrefix string, store cache.Cache) *Captcha { cpt.StdWidth = stdWidth cpt.StdHeight = stdHeight - if len(urlPrefix) == 0 { + if urlPrefix == "" { urlPrefix = defaultURLPrefix } @@ -257,14 +261,23 @@ func NewCaptcha(urlPrefix string, store cache.Cache) *Captcha { // NewWithFilter create a new captcha.Captcha and auto AddFilter for serve captacha image // and add a template func for output html -func NewWithFilter(urlPrefix string, store cache.Cache) *Captcha { +func NewWithFilter(urlPrefix string, store Storage) *Captcha { cpt := NewCaptcha(urlPrefix, store) // create filter for serve captcha image - beego.InsertFilter(cpt.URLPrefix+"*", beego.BeforeRouter, cpt.Handler) + web.InsertFilter(cpt.URLPrefix+"*", web.BeforeRouter, cpt.Handler) // add to template func map - beego.AddFuncMap("create_captcha", cpt.CreateCaptchaHTML) + web.AddFuncMap("create_captcha", cpt.CreateCaptchaHTML) return cpt } + +type Storage interface { + // Get a cached value by key. + Get(ctx context2.Context, key string) (interface{}, error) + // Set a cached value with key and expire time. + Put(ctx context2.Context, key string, val interface{}, timeout time.Duration) error + // Delete cached value by key. + Delete(ctx context2.Context, key string) error +} diff --git a/utils/captcha/image.go b/server/web/captcha/image.go similarity index 100% rename from utils/captcha/image.go rename to server/web/captcha/image.go diff --git a/utils/captcha/image_test.go b/server/web/captcha/image_test.go similarity index 97% rename from utils/captcha/image_test.go rename to server/web/captcha/image_test.go index 5e35b7f779..53e8957584 100644 --- a/utils/captcha/image_test.go +++ b/server/web/captcha/image_test.go @@ -17,7 +17,7 @@ package captcha import ( "testing" - "github.com/astaxie/beego/utils" + "github.com/beego/beego/v2/core/utils" ) type byteCounter struct { diff --git a/utils/captcha/siprng.go b/server/web/captcha/siprng.go similarity index 98% rename from utils/captcha/siprng.go rename to server/web/captcha/siprng.go index 5e256cf93a..4fdf314301 100644 --- a/utils/captcha/siprng.go +++ b/server/web/captcha/siprng.go @@ -27,7 +27,7 @@ type siprng struct { k0, k1, ctr uint64 } -// siphash implements SipHash-2-4, accepting a uint64 as a message. +// siphash implements SipHash-2-4, accepting an uint64 as a message. func siphash(k0, k1, m uint64) uint64 { // Initialization. v0 := k0 ^ 0x736f6d6570736575 diff --git a/utils/captcha/siprng_test.go b/server/web/captcha/siprng_test.go similarity index 100% rename from utils/captcha/siprng_test.go rename to server/web/captcha/siprng_test.go diff --git a/server/web/config.go b/server/web/config.go new file mode 100644 index 0000000000..3cef5eed3d --- /dev/null +++ b/server/web/config.go @@ -0,0 +1,877 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package web + +import ( + "crypto/tls" + "fmt" + "net/http" + "os" + "path/filepath" + "reflect" + "runtime" + "strings" + + "github.com/beego/beego/v2" + "github.com/beego/beego/v2/core/config" + "github.com/beego/beego/v2/core/logs" + "github.com/beego/beego/v2/core/utils" + "github.com/beego/beego/v2/server/web/context" + "github.com/beego/beego/v2/server/web/session" +) + +// Config is the main struct for BConfig +// TODO after supporting multiple servers, remove common config to somewhere else +type Config struct { + // AppName + // @Description Application's name. You'd better set it because we use it to do some logging and tracing + // @Default beego + AppName string // Application name + // RunMode + // @Description it's the same as environment. In general, we have different run modes. + // For example, the most common case is using dev, test, prod three environments + // when you are developing the application, you should set it as dev + // when you completed coding and want QA to test your code, you should deploy your application to test environment + // and the RunMode should be set as test + // when you completed all tests, you want to deploy it to prod, you should set it to prod + // You should never set RunMode="dev" when you deploy the application to prod + // because Beego will do more things which need Go SDK and other tools when it found out the RunMode="dev" + // @Default dev + RunMode string // Running Mode: dev | prod + + // RouterCaseSensitive + // @Description If it was true, it means that the router is case sensitive. + // For example, when you register a router with pattern "/hello", + // 1. If this is true, and the request URL is "/Hello", it won't match this pattern + // 2. If this is false and the request URL is "/Hello", it will match this pattern + // @Default true + RouterCaseSensitive bool + // RecoverPanic + // @Description if it was true, Beego will try to recover from panic when it serves your http request + // So you should notice that it doesn't mean that Beego will recover all panic cases. + // @Default true + RecoverPanic bool + // CopyRequestBody + // @Description if it's true, Beego will copy the request body. But if the request body's size > MaxMemory, + // Beego will return 413 as http status + // If you are building RESTful API, please set it to true. + // And if you want to read data from request Body multiple times, please set it to true + // In general, if you don't meet any performance issue, you could set it to true + // @Default false + CopyRequestBody bool + // EnableGzip + // @Description If it was true, Beego will try to compress data by using zip algorithm. + // But there are two points: + // 1. Only static resources will be compressed + // 2. Only those static resource which has the extension specified by StaticExtensionsToGzip will be compressed + // @Default false + EnableGzip bool + // EnableErrorsShow + // @Description If it's true, Beego will show error message to page + // it will work with ErrorMaps which allows you register some error handler + // You may want to set it to false when application was deploy to prod environment + // because you may not want to expose your internal error msg to your users + // it's a little bit unsafe + // @Default true + EnableErrorsShow bool + // EnableErrorsRender + // @Description If it's true, it will output the error msg as a page. It's similar to EnableErrorsShow + // And this configure item only work in dev run mode (see RunMode) + // @Default true + EnableErrorsRender bool + // ServerName + // @Description server name. For example, in large scale system, + // you may want to deploy your application to several machines, so that each of them has a server name + // we suggest you'd better set value because Beego use this to output some DEBUG msg, + // or integrated with other tools such as tracing, metrics + // @Default + ServerName string + + // RecoverFunc + // @Description when Beego want to recover from panic, it will use this func as callback + // see RecoverPanic + // @Default defaultRecoverPanic + RecoverFunc func(*context.Context, *Config) + // @Description MaxMemory and MaxUploadSize are used to limit the request body + // if the request is not uploading file, MaxMemory is the max size of request body + // if the request is uploading file, MaxUploadSize is the max size of request body + // if CopyRequestBody is true, this value will be used as the threshold of request body + // see CopyRequestBody + // the default value is 1 << 26 (64MB) + // @Default 67108864 + MaxMemory int64 + // MaxUploadSize + // @Description MaxMemory and MaxUploadSize are used to limit the request body + // if the request is not uploading file, MaxMemory is the max size of request body + // if the request is uploading file, MaxUploadSize is the max size of request body + // the default value is 1 << 30 (1GB) + // @Default 1073741824 + MaxUploadSize int64 + // Listen + // @Description the configuration about socket or http protocol + Listen Listen + // WebConfig + // @Description the configuration about Web + WebConfig WebConfig + // LogConfig + // @Description log configuration + Log LogConfig + // ControllerSuffix + // @Description the suffix of controller for usage of adding auto router + ControllerSuffix string +} + +// Listen holds for http and https related config +type Listen struct { + // Graceful + // @Description means use graceful module to start the server + // @Default false + Graceful bool + // ListenTCP4 + // @Description if it's true, means that Beego only work for TCP4 + // please check net.Listen function + // In general, you should not set it to true + // @Default false + ListenTCP4 bool + // EnableHTTP + // @Description if it's true, Beego will accept HTTP request. + // But if you want to use HTTPS only, please set it to false + // see EnableHTTPS + // @Default true + EnableHTTP bool + // AutoTLS + // @Description If it's true, Beego will use default value to initialize the TLS configure + // But those values could be override if you have custom value. + // see Domains, TLSCacheDir + // @Default false + AutoTLS bool + // EnableHTTPS + // @Description If it's true, Beego will accept HTTPS request. + // Now, you'd better use HTTPS protocol on prod environment to get better security + // In prod, the best option is EnableHTTPS=true and EnableHTTP=false + // see EnableHTTP + // @Default false + EnableHTTPS bool + // EnableMutualHTTPS + // @Description if it's true, Beego will handle requests on incoming mutual TLS connections + // see Server.ListenAndServeMutualTLS + // @Default false + EnableMutualHTTPS bool + // EnableAdmin + // @Description if it's true, Beego will provide admin service. + // You can visit the admin service via browser. + // The default port is 8088 + // see AdminPort + // @Default false + EnableAdmin bool + // EnableFcgi + // @Description + // @Default false + EnableFcgi bool + // EnableStdIo + // @Description EnableStdIo works with EnableFcgi Use FCGI via standard I/O + // @Default false + EnableStdIo bool + // ServerTimeOut + // @Description Beego use this as ReadTimeout and WriteTimeout + // The unit is second. + // see http.Server.ReadTimeout, WriteTimeout + // @Default 0 + ServerTimeOut int64 + // HTTPAddr + // @Description Beego listen to this address when the application start up. + // @Default "" + HTTPAddr string + // HTTPPort + // @Description Beego listen to this port + // you'd better change this value when you deploy to prod environment + // @Default 8080 + HTTPPort int + // Domains + // @Description Beego use this to configure TLS. Those domains are "white list" domain + // @Default [] + Domains []string + // TLSCacheDir + // @Description Beego use this as cache dir to store TLS cert data + // @Default "" + TLSCacheDir string + // HTTPSAddr + // @Description Beego will listen to this address to accept HTTPS request + // see EnableHTTPS + // @Default "" + HTTPSAddr string + // HTTPSPort + // @Description Beego will listen to this port to accept HTTPS request + // @Default 10443 + HTTPSPort int + // HTTPSCertFile + // @Description Beego read this file as cert file + // When you are using HTTPS protocol, please configure it + // see HTTPSKeyFile + // @Default "" + HTTPSCertFile string + // HTTPSKeyFile + // @Description Beego read this file as key file + // When you are using HTTPS protocol, please configure it + // see HTTPSCertFile + // @Default "" + HTTPSKeyFile string + // TrustCaFile + // @Description Beego read this file as CA file + // @Default "" + TrustCaFile string + // AdminAddr + // @Description Beego will listen to this address to provide admin service + // In general, it should be the same with your application address, HTTPAddr or HTTPSAddr + // @Default "" + AdminAddr string + // AdminPort + // @Description Beego will listen to this port to provide admin service + // @Default 8088 + AdminPort int + // @Description Beego use this tls.ClientAuthType to initialize TLS connection + // The default value is tls.RequireAndVerifyClientCert + // @Default 4 + ClientAuth int +} + +// WebConfig holds web related config +type WebConfig struct { + // AutoRender + // @Description If it's true, Beego will render the page based on your template and data + // In general, keep it as true. + // But if you are building RESTFul API and you don't have any page, + // you can set it to false + // @Default true + AutoRender bool + // Deprecated: Beego didn't use it anymore + EnableDocs bool + // EnableXSRF + // @Description If it's true, Beego will help to provide XSRF support + // But you should notice that, now Beego only work for HTTPS protocol with XSRF + // because it's not safe if using HTTP protocol + // And, the cookie storing XSRF token has two more flags HttpOnly and Secure + // It means that you must use HTTPS protocol and you can not read the token from JS script + // This is completed different from Beego 1.x because we got many security reports + // And if you are in dev environment, you could set it to false + // @Default false + EnableXSRF bool + // DirectoryIndex + // @Description When Beego serves static resources request, it will look up the file. + // If the file is directory, Beego will try to find the index.html as the response + // But if the index.html is not exist or it's a directory, + // Beego will return 403 response if DirectoryIndex is **false** + // @Default false + DirectoryIndex bool + // FlashName + // @Description the cookie's name when Beego try to store the flash data into cookie + // @Default BEEGO_FLASH + FlashName string + // FlashSeparator + // @Description When Beego read flash data from request, it uses this as the separator + // @Default BEEGOFLASH + FlashSeparator string + // StaticDir + // @Description Beego uses this as static resources' root directory. + // It means that Beego will try to search static resource from this start point + // It's a map, the key is the path and the value is the directory + // For example, the default value is /static => static, + // which means that when Beego got a request with path /static/xxx + // Beego will try to find the resource from static directory + // @Default /static => static + StaticDir map[string]string + // StaticExtensionsToGzip + // @Description The static resources with those extension will be compressed if EnableGzip is true + // @Default [".css", ".js" ] + StaticExtensionsToGzip []string + // StaticCacheFileSize + // @Description If the size of static resource < StaticCacheFileSize, Beego will try to handle it by itself, + // it means that Beego will compressed the file data (if enable) and cache this file. + // But if the file size > StaticCacheFileSize, Beego just simply delegate the request to http.ServeFile + // the default value is 100KB. + // the max memory size of caching static files is StaticCacheFileSize * StaticCacheFileNum + // see StaticCacheFileNum + // @Default 102400 + StaticCacheFileSize int + // StaticCacheFileNum + // @Description Beego use it to control the memory usage of caching static resource file + // If the caching files > StaticCacheFileNum, Beego use LRU algorithm to remove caching file + // the max memory size of caching static files is StaticCacheFileSize * StaticCacheFileNum + // see StaticCacheFileSize + // @Default 1000 + StaticCacheFileNum int + // TemplateLeft + // @Description Beego use this to render page + // see TemplateRight + // @Default {{ + TemplateLeft string + // TemplateRight + // @Description Beego use this to render page + // see TemplateLeft + // @Default }} + TemplateRight string + // ViewsPath + // @Description The directory of Beego application storing template + // @Default views + ViewsPath string + // CommentRouterPath + // @Description Beego scans this directory and its sub directory to generate router + // Beego only scans this directory when it's in dev environment + // @Default controllers + CommentRouterPath string + // XSRFKey + // @Description the name of cookie storing XSRF token + // see EnableXSRF + // @Default beegoxsrf + XSRFKey string + // XSRFExpire + // @Description the expiration time of XSRF token cookie + // second + // @Default 0 + XSRFExpire int + // @Description session related config + Session SessionConfig +} + +// SessionConfig holds session related config +type SessionConfig struct { + // SessionOn + // @Description if it's true, Beego will auto manage session + // @Default false + SessionOn bool + // SessionAutoSetCookie + // @Description if it's true, Beego will put the session token into cookie too + // @Default true + SessionAutoSetCookie bool + // SessionDisableHTTPOnly + // @Description used to allow for cross domain cookies/javascript cookies + // In general, you should not set it to true unless you understand the risk + // @Default false + SessionDisableHTTPOnly bool + // SessionEnableSidInHTTPHeader + // @Description enable store/get the sessionId into/from http headers + // @Default false + SessionEnableSidInHTTPHeader bool + // SessionEnableSidInURLQuery + // @Description enable get the sessionId from Url Query params + // @Default false + SessionEnableSidInURLQuery bool + // SessionProvider + // @Description session provider's name. + // You should confirm that this provider has been register via session.Register method + // the default value is memory. This is not suitable for distributed system + // @Default memory + SessionProvider string + // SessionName + // @Description If SessionAutoSetCookie is true, we use this value as the cookie's name + // @Default beegosessionID + SessionName string + // SessionGCMaxLifetime + // @Description Beego will GC session to clean useless session. + // unit: second + // @Default 3600 + SessionGCMaxLifetime int64 + // SessionProviderConfig + // @Description the config of session provider + // see SessionProvider + // you should read the document of session provider to learn how to set this value + // @Default "" + SessionProviderConfig string + // SessionCookieLifeTime + // @Description If SessionAutoSetCookie is true, + // we use this value as the expiration time and max age of the cookie + // unit second + // @Default 0 + SessionCookieLifeTime int + // SessionDomain + // @Description If SessionAutoSetCookie is true, we use this value as the cookie's domain + // @Default "" + SessionDomain string + // SessionNameInHTTPHeader + // @Description if SessionEnableSidInHTTPHeader is true, this value will be used as the http header + // @Default Beegosessionid + SessionNameInHTTPHeader string + // SessionCookieSameSite + // @Description If SessionAutoSetCookie is true, we use this value as the cookie's same site policy + // the default value is http.SameSiteDefaultMode + // @Default 1 + SessionCookieSameSite http.SameSite + + // SessionIDPrefix + // @Description session id's prefix + // @Default "" + SessionIDPrefix string +} + +// LogConfig holds Log related config +type LogConfig struct { + // AccessLogs + // @Description If it's true, Beego will log the HTTP request info + // @Default false + AccessLogs bool + // EnableStaticLogs + // @Description log static files requests + // @Default false + EnableStaticLogs bool + // FileLineNum + // @Description if it's true, it will log the line number + // @Default true + FileLineNum bool + // AccessLogsFormat + // @Description access log format: JSON_FORMAT, APACHE_FORMAT or empty string + // @Default APACHE_FORMAT + AccessLogsFormat string + // Outputs + // @Description the destination of access log + // the key is log adapter and the value is adapter's configure + // @Default "console" => "" + Outputs map[string]string // Store Adaptor : config +} + +var ( + // BConfig is the default config for Application + BConfig *Config + // AppConfig is the instance of Config, store the config information from file + AppConfig *beegoAppConfig + // AppPath is the absolute path to the app + AppPath string + // GlobalSessions is the instance for the session manager + GlobalSessions *session.Manager + + // appConfigPath is the path to the config files + appConfigPath string + // appConfigProvider is the provider for the config, default is ini + appConfigProvider = "ini" + // WorkPath is the absolute path to project root directory + WorkPath string +) + +func init() { + BConfig = newBConfig() + var err error + if AppPath, err = filepath.Abs(filepath.Dir(os.Args[0])); err != nil { + panic(err) + } + WorkPath, err = os.Getwd() + if err != nil { + panic(err) + } + filename := "app.conf" + if os.Getenv("BEEGO_RUNMODE") != "" { + filename = os.Getenv("BEEGO_RUNMODE") + ".app.conf" + } + appConfigPath = filepath.Join(WorkPath, "conf", filename) + if !utils.FileExists(appConfigPath) { + appConfigPath = filepath.Join(AppPath, "conf", filename) + if !utils.FileExists(appConfigPath) { + AppConfig = &beegoAppConfig{innerConfig: config.NewFakeConfig()} + return + } + } + if err = parseConfig(appConfigPath); err != nil { + panic(err) + } +} + +func defaultRecoverPanic(ctx *context.Context, cfg *Config) { + if err := recover(); err != nil { + if err == ErrAbort { + return + } + if !cfg.RecoverPanic { + panic(err) + } + if cfg.EnableErrorsShow { + if _, ok := ErrorMaps[fmt.Sprint(err)]; ok { + exception(fmt.Sprint(err), ctx) + return + } + } + var stack string + logs.Critical("the request url is ", ctx.Input.URL()) + logs.Critical("Handler crashed with error", err) + for i := 1; ; i++ { + _, file, line, ok := runtime.Caller(i) + if !ok { + break + } + logs.Critical(fmt.Sprintf("%s:%d", file, line)) + stack += fmt.Sprintf("%s:%d\n", file, line) + } + + if ctx.Output.Status != 0 { + ctx.ResponseWriter.WriteHeader(ctx.Output.Status) + } else { + ctx.ResponseWriter.WriteHeader(500) + } + + if cfg.RunMode == DEV && cfg.EnableErrorsRender { + showErr(err, ctx, stack) + } + } +} + +func newBConfig() *Config { + res := &Config{ + AppName: "beego", + RunMode: PROD, + RouterCaseSensitive: true, + ServerName: "beegoServer:" + beego.VERSION, + RecoverPanic: true, + + CopyRequestBody: false, + EnableGzip: false, + MaxMemory: 1 << 26, // 64MB + MaxUploadSize: 1 << 30, // 1GB + EnableErrorsShow: true, + EnableErrorsRender: true, + Listen: Listen{ + Graceful: false, + ServerTimeOut: 0, + ListenTCP4: false, + EnableHTTP: true, + AutoTLS: false, + Domains: []string{}, + TLSCacheDir: ".", + HTTPAddr: "", + HTTPPort: 8080, + EnableHTTPS: false, + HTTPSAddr: "", + HTTPSPort: 10443, + HTTPSCertFile: "", + HTTPSKeyFile: "", + EnableAdmin: false, + AdminAddr: "", + AdminPort: 8088, + EnableFcgi: false, + EnableStdIo: false, + ClientAuth: int(tls.RequireAndVerifyClientCert), + }, + WebConfig: WebConfig{ + AutoRender: true, + EnableDocs: false, + FlashName: "BEEGO_FLASH", + FlashSeparator: "BEEGOFLASH", + DirectoryIndex: false, + StaticDir: map[string]string{"/static": "static"}, + StaticExtensionsToGzip: []string{".css", ".js"}, + StaticCacheFileSize: 1024 * 100, + StaticCacheFileNum: 1000, + TemplateLeft: "{{", + TemplateRight: "}}", + ViewsPath: "views", + CommentRouterPath: "controllers", + EnableXSRF: false, + XSRFKey: "beegoxsrf", + XSRFExpire: 0, + Session: SessionConfig{ + SessionOn: false, + SessionProvider: "memory", + SessionName: "beegosessionID", + SessionGCMaxLifetime: 3600, + SessionProviderConfig: "", + SessionDisableHTTPOnly: false, + SessionCookieLifeTime: 0, // set cookie default is the browser life + SessionAutoSetCookie: true, + SessionDomain: "", + SessionEnableSidInHTTPHeader: false, // enable store/get the sessionId into/from http headers + SessionNameInHTTPHeader: "Beegosessionid", + SessionEnableSidInURLQuery: false, // enable get the sessionId from Url Query params + SessionCookieSameSite: http.SameSiteDefaultMode, + }, + }, + Log: LogConfig{ + AccessLogs: false, + EnableStaticLogs: false, + AccessLogsFormat: "APACHE_FORMAT", + FileLineNum: true, + Outputs: map[string]string{"console": ""}, + }, + + ControllerSuffix: "Controller", + } + + res.RecoverFunc = defaultRecoverPanic + return res +} + +// now only support ini, next will support json. +func parseConfig(appConfigPath string) (err error) { + AppConfig, err = newAppConfig(appConfigProvider, appConfigPath) + if err != nil { + return err + } + return assignConfig(AppConfig) +} + +// assignConfig is tricky. +// For 1.x, it use assignSingleConfig to parse the file +// but for 2.x, we use Unmarshaler method +func assignConfig(ac config.Configer) error { + parseConfigForV1(ac) + + err := ac.Unmarshaler("", BConfig) + if err != nil { + _, _ = fmt.Fprintln(os.Stderr, fmt.Sprintf("Unmarshaler config file to BConfig failed. "+ + "And if you are working on v1.x config file, please ignore this, err: %s", err)) + return err + } + + // init log + logs.Reset() + for adaptor, cfg := range BConfig.Log.Outputs { + err := logs.SetLogger(adaptor, cfg) + if err != nil { + fmt.Fprintf(os.Stderr, "%s with the config %q got err:%s\n", adaptor, cfg, err.Error()) + return err + } + } + logs.SetLogFuncCall(BConfig.Log.FileLineNum) + return nil +} + +func parseConfigForV1(ac config.Configer) { + for _, i := range []interface{}{BConfig, &BConfig.Listen, &BConfig.WebConfig, &BConfig.Log, &BConfig.WebConfig.Session} { + assignSingleConfig(i, ac) + } + + // set the run mode first + if envRunMode := os.Getenv("BEEGO_RUNMODE"); envRunMode != "" { + BConfig.RunMode = envRunMode + } else if runMode, err := ac.String("RunMode"); runMode != "" && err == nil { + BConfig.RunMode = runMode + } + + if sd, err := ac.String("StaticDir"); sd != "" && err == nil { + BConfig.WebConfig.StaticDir = map[string]string{} + sds := strings.Fields(sd) + for _, v := range sds { + if url2fsmap := strings.SplitN(v, ":", 2); len(url2fsmap) == 2 { + BConfig.WebConfig.StaticDir["/"+strings.Trim(url2fsmap[0], "/")] = url2fsmap[1] + } else { + BConfig.WebConfig.StaticDir["/"+strings.Trim(url2fsmap[0], "/")] = url2fsmap[0] + } + } + } + + if sgz, err := ac.String("StaticExtensionsToGzip"); sgz != "" && err == nil { + extensions := strings.Split(sgz, ",") + fileExts := []string{} + for _, ext := range extensions { + ext = strings.TrimSpace(ext) + if ext == "" { + continue + } + if !strings.HasPrefix(ext, ".") { + ext = "." + ext + } + fileExts = append(fileExts, ext) + } + if len(fileExts) > 0 { + BConfig.WebConfig.StaticExtensionsToGzip = fileExts + } + } + + if sfs, err := ac.Int("StaticCacheFileSize"); err == nil { + BConfig.WebConfig.StaticCacheFileSize = sfs + } + + if sfn, err := ac.Int("StaticCacheFileNum"); err == nil { + BConfig.WebConfig.StaticCacheFileNum = sfn + } + + if lo, err := ac.String("LogOutputs"); lo != "" && err == nil { + // if lo is not nil or empty + // means user has set his own LogOutputs + // clear the default setting to BConfig.Log.Outputs + BConfig.Log.Outputs = make(map[string]string) + los := strings.Split(lo, ";") + for _, v := range los { + if logType2Config := strings.SplitN(v, ",", 2); len(logType2Config) == 2 { + BConfig.Log.Outputs[logType2Config[0]] = logType2Config[1] + } else { + continue + } + } + } +} + +func assignSingleConfig(p interface{}, ac config.Configer) { + pt := reflect.TypeOf(p) + if pt.Kind() != reflect.Ptr { + return + } + pt = pt.Elem() + if pt.Kind() != reflect.Struct { + return + } + pv := reflect.ValueOf(p).Elem() + + for i := 0; i < pt.NumField(); i++ { + pf := pv.Field(i) + if !pf.CanSet() { + continue + } + name := pt.Field(i).Name + switch pf.Kind() { + case reflect.String: + pf.SetString(ac.DefaultString(name, pf.String())) + case reflect.Int, reflect.Int64: + pf.SetInt(ac.DefaultInt64(name, pf.Int())) + case reflect.Bool: + pf.SetBool(ac.DefaultBool(name, pf.Bool())) + case reflect.Struct: + default: + // do nothing here + } + } +} + +// LoadAppConfig allow developer to apply a config file +func LoadAppConfig(adapterName, configPath string) error { + absConfigPath, err := filepath.Abs(configPath) + if err != nil { + return err + } + + if !utils.FileExists(absConfigPath) { + return fmt.Errorf("the target config file: %s don't exist", configPath) + } + + appConfigPath = absConfigPath + appConfigProvider = adapterName + + return parseConfig(appConfigPath) +} + +type beegoAppConfig struct { + config.BaseConfiger + innerConfig config.Configer +} + +func newAppConfig(appConfigProvider, appConfigPath string) (*beegoAppConfig, error) { + ac, err := config.NewConfig(appConfigProvider, appConfigPath) + if err != nil { + return nil, err + } + return &beegoAppConfig{innerConfig: ac}, nil +} + +func (b *beegoAppConfig) Unmarshaler(prefix string, obj interface{}, opt ...config.DecodeOption) error { + return b.innerConfig.Unmarshaler(prefix, obj, opt...) +} + +func (b *beegoAppConfig) Set(key, val string) error { + if err := b.innerConfig.Set(BConfig.RunMode+"::"+key, val); err != nil { + return b.innerConfig.Set(key, val) + } + return nil +} + +func (b *beegoAppConfig) String(key string) (string, error) { + if v, err := b.innerConfig.String(BConfig.RunMode + "::" + key); v != "" && err == nil { + return v, nil + } + return b.innerConfig.String(key) +} + +func (b *beegoAppConfig) Strings(key string) ([]string, error) { + if v, err := b.innerConfig.Strings(BConfig.RunMode + "::" + key); len(v) > 0 && err == nil { + return v, nil + } + return b.innerConfig.Strings(key) +} + +func (b *beegoAppConfig) Int(key string) (int, error) { + if v, err := b.innerConfig.Int(BConfig.RunMode + "::" + key); err == nil { + return v, nil + } + return b.innerConfig.Int(key) +} + +func (b *beegoAppConfig) Int64(key string) (int64, error) { + if v, err := b.innerConfig.Int64(BConfig.RunMode + "::" + key); err == nil { + return v, nil + } + return b.innerConfig.Int64(key) +} + +func (b *beegoAppConfig) Bool(key string) (bool, error) { + if v, err := b.innerConfig.Bool(BConfig.RunMode + "::" + key); err == nil { + return v, nil + } + return b.innerConfig.Bool(key) +} + +func (b *beegoAppConfig) Float(key string) (float64, error) { + if v, err := b.innerConfig.Float(BConfig.RunMode + "::" + key); err == nil { + return v, nil + } + return b.innerConfig.Float(key) +} + +func (b *beegoAppConfig) DefaultString(key string, defaultVal string) string { + if v, err := b.String(key); v != "" && err == nil { + return v + } + return defaultVal +} + +func (b *beegoAppConfig) DefaultStrings(key string, defaultVal []string) []string { + if v, err := b.Strings(key); len(v) != 0 && err == nil { + return v + } + return defaultVal +} + +func (b *beegoAppConfig) DefaultInt(key string, defaultVal int) int { + if v, err := b.Int(key); err == nil { + return v + } + return defaultVal +} + +func (b *beegoAppConfig) DefaultInt64(key string, defaultVal int64) int64 { + if v, err := b.Int64(key); err == nil { + return v + } + return defaultVal +} + +func (b *beegoAppConfig) DefaultBool(key string, defaultVal bool) bool { + if v, err := b.Bool(key); err == nil { + return v + } + return defaultVal +} + +func (b *beegoAppConfig) DefaultFloat(key string, defaultVal float64) float64 { + if v, err := b.Float(key); err == nil { + return v + } + return defaultVal +} + +func (b *beegoAppConfig) DIY(key string) (interface{}, error) { + return b.innerConfig.DIY(key) +} + +func (b *beegoAppConfig) GetSection(section string) (map[string]string, error) { + return b.innerConfig.GetSection(section) +} + +func (b *beegoAppConfig) SaveConfigFile(filename string) error { + return b.innerConfig.SaveConfigFile(filename) +} diff --git a/config_test.go b/server/web/config_test.go similarity index 85% rename from config_test.go rename to server/web/config_test.go index 53411b0168..2be4fd73b5 100644 --- a/config_test.go +++ b/server/web/config_test.go @@ -12,14 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. -package beego +package web import ( "encoding/json" "reflect" "testing" - "github.com/astaxie/beego/config" + beeJson "github.com/beego/beego/v2/core/config/json" ) func TestDefaults(t *testing.T) { @@ -30,12 +30,20 @@ func TestDefaults(t *testing.T) { if BConfig.WebConfig.FlashSeparator != "BEEGOFLASH" { t.Errorf("FlashName was not set to default.") } + + if BConfig.ControllerSuffix != "Controller" { + t.Errorf("ControllerSuffix was not set to default.") + } +} + +func TestLoadAppConfig(t *testing.T) { + println(1 << 30) } func TestAssignConfig_01(t *testing.T) { _BConfig := &Config{} _BConfig.AppName = "beego_test" - jcf := &config.JSONConfig{} + jcf := &beeJson.JSONConfig{} ac, _ := jcf.ParseData([]byte(`{"AppName":"beego_json"}`)) assignSingleConfig(_BConfig, ac) if _BConfig.AppName != "beego_json" { @@ -73,7 +81,7 @@ func TestAssignConfig_02(t *testing.T) { configMap["SessionProviderConfig"] = "file" configMap["FileLineNum"] = true - jcf := &config.JSONConfig{} + jcf := &beeJson.JSONConfig{} bs, _ = json.Marshal(configMap) ac, _ := jcf.ParseData(bs) @@ -105,16 +113,17 @@ func TestAssignConfig_02(t *testing.T) { t.Log(_BConfig.Log.FileLineNum) t.FailNow() } - } func TestAssignConfig_03(t *testing.T) { - jcf := &config.JSONConfig{} + jcf := &beeJson.JSONConfig{} ac, _ := jcf.ParseData([]byte(`{"AppName":"beego"}`)) ac.Set("AppName", "test_app") ac.Set("RunMode", "online") ac.Set("StaticDir", "download:down download2:down2") ac.Set("StaticExtensionsToGzip", ".css,.js,.html,.jpg,.png") + ac.Set("StaticCacheFileSize", "87456") + ac.Set("StaticCacheFileNum", "1254") assignConfig(ac) t.Logf("%#v", BConfig) @@ -132,6 +141,12 @@ func TestAssignConfig_03(t *testing.T) { if BConfig.WebConfig.StaticDir["/download2"] != "down2" { t.FailNow() } + if BConfig.WebConfig.StaticCacheFileSize != 87456 { + t.FailNow() + } + if BConfig.WebConfig.StaticCacheFileNum != 1254 { + t.FailNow() + } if len(BConfig.WebConfig.StaticExtensionsToGzip) != 5 { t.FailNow() } diff --git a/context/acceptencoder.go b/server/web/context/acceptencoder.go similarity index 81% rename from context/acceptencoder.go rename to server/web/context/acceptencoder.go index b4e2492c0f..4181bfee96 100644 --- a/context/acceptencoder.go +++ b/server/web/context/acceptencoder.go @@ -28,18 +28,18 @@ import ( ) var ( - //Default size==20B same as nginx + // Default size==20B same as nginx defaultGzipMinLength = 20 - //Content will only be compressed if content length is either unknown or greater than gzipMinLength. + // Content will only be compressed if content length is either unknown or greater than gzipMinLength. gzipMinLength = defaultGzipMinLength - //The compression level used for deflate compression. (0-9). + // Compression level used for deflate compression. (0-9). gzipCompressLevel int - //List of HTTP methods to compress. If not set, only GET requests are compressed. + // List of HTTP methods to compress. If not set, only GET requests are compressed. includedMethods map[string]bool getMethodOnly bool ) -// InitGzip init the gzipcompress +// InitGzip initializes the gzipcompress func InitGzip(minLength, compressLevel int, methods []string) { if minLength >= 0 { gzipMinLength = minLength @@ -65,7 +65,7 @@ type nopResetWriter struct { } func (n nopResetWriter) Reset(w io.Writer) { - //do nothing + // do nothing } type acceptEncoder struct { @@ -98,9 +98,9 @@ func (ac acceptEncoder) put(wr resetWriter, level int) { } wr.Reset(nil) - //notice - //compressionLevel==BestCompression DOES NOT MATTER - //sync.Pool will not memory leak + // notice + // compressionLevel==BestCompression DOES NOT MATTER + // sync.Pool will not memory leak switch level { case gzipCompressLevel: @@ -119,10 +119,10 @@ var ( bestCompressionPool: &sync.Pool{New: func() interface{} { wr, _ := gzip.NewWriterLevel(nil, flate.BestCompression); return wr }}, } - //according to the sec :http://tools.ietf.org/html/rfc2616#section-3.5 ,the deflate compress in http is zlib indeed - //deflate - //The "zlib" format defined in RFC 1950 [31] in combination with - //the "deflate" compression mechanism described in RFC 1951 [29]. + // According to: http://tools.ietf.org/html/rfc2616#section-3.5 the deflate compress in http is zlib indeed + // deflate + // The "zlib" format defined in RFC 1950 [31] in combination with + // the "deflate" compression mechanism described in RFC 1951 [29]. deflateCompressEncoder = acceptEncoder{ name: "deflate", levelEncode: func(level int) resetWriter { wr, _ := zlib.NewWriterLevel(nil, level); return wr }, @@ -131,21 +131,19 @@ var ( } ) -var ( - encoderMap = map[string]acceptEncoder{ // all the other compress methods will ignore - "gzip": gzipCompressEncoder, - "deflate": deflateCompressEncoder, - "*": gzipCompressEncoder, // * means any compress will accept,we prefer gzip - "identity": noneCompressEncoder, // identity means none-compress - } -) +var encoderMap = map[string]acceptEncoder{ // all the other compress methods will ignore + "gzip": gzipCompressEncoder, + "deflate": deflateCompressEncoder, + "*": gzipCompressEncoder, // * means any compress will accept,we prefer gzip + "identity": noneCompressEncoder, // identity means none-compress +} // WriteFile reads from file and writes to writer by the specific encoding(gzip/deflate) func WriteFile(encoding string, writer io.Writer, file *os.File) (bool, string, error) { return writeLevel(encoding, writer, file, flate.BestCompression) } -// WriteBody reads writes content to writer by the specific encoding(gzip/deflate) +// WriteBody reads writes content to writer by the specific encoding(gzip/deflate) func WriteBody(encoding string, writer io.Writer, content []byte) (bool, string, error) { if encoding == "" || len(content) < gzipMinLength { _, err := writer.Write(content) @@ -154,12 +152,12 @@ func WriteBody(encoding string, writer io.Writer, content []byte) (bool, string, return writeLevel(encoding, writer, bytes.NewReader(content), gzipCompressLevel) } -// writeLevel reads from reader,writes to writer by specific encoding and compress level -// the compress level is defined by deflate package +// writeLevel reads from reader and writes to writer by specific encoding and compress level. +// The compress level is defined by deflate package func writeLevel(encoding string, writer io.Writer, reader io.Reader, level int) (bool, string, error) { var outputWriter resetWriter var err error - var ce = noneCompressEncoder + ce := noneCompressEncoder if cf, ok := encoderMap[encoding]; ok { ce = cf diff --git a/context/acceptencoder_test.go b/server/web/context/acceptencoder_test.go similarity index 100% rename from context/acceptencoder_test.go rename to server/web/context/acceptencoder_test.go diff --git a/server/web/context/context.go b/server/web/context/context.go new file mode 100644 index 0000000000..1c209c3147 --- /dev/null +++ b/server/web/context/context.go @@ -0,0 +1,421 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package context provide the context utils +// Usage: +// +// import "github.com/beego/beego/v2/server/web/context" +// +// ctx := context.Context{Request:req,ResponseWriter:rw} +package context + +import ( + "bufio" + "crypto/hmac" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "encoding/xml" + "errors" + "fmt" + "net" + "net/http" + "net/url" + "reflect" + "strconv" + "strings" + "time" + + "google.golang.org/protobuf/proto" + "gopkg.in/yaml.v3" + + "github.com/beego/beego/v2/core/utils" + "github.com/beego/beego/v2/server/web/session" +) + +// Commonly used mime-types +const ( + ApplicationJSON = "application/json" + ApplicationXML = "application/xml" + ApplicationForm = "application/x-www-form-urlencoded" + ApplicationProto = "application/x-protobuf" + ApplicationYAML = "application/x-yaml" + TextXML = "text/xml" + + formatTime = "15:04:05" + formatDate = "2006-01-02" + formatDateTime = "2006-01-02 15:04:05" + formatDateTimeT = "2006-01-02T15:04:05" +) + +// NewContext return the Context with Input and Output +func NewContext() *Context { + return &Context{ + Input: NewInput(), + Output: NewOutput(), + } +} + +// Context Http request context struct including BeegoInput, BeegoOutput, http.Request and http.ResponseWriter. +// BeegoInput and BeegoOutput provides an api to operate request and response more easily. +type Context struct { + Input *BeegoInput + Output *BeegoOutput + Request *http.Request + ResponseWriter *Response + _xsrfToken string +} + +func (ctx *Context) EventStreamResp() chan<- []byte { + eventCh := make(chan []byte) + go func() { + for { + select { + case eventData, ok := <-eventCh: + if !ok { + return + } + sendEvent(ctx, eventData) + case <-ctx.Request.Context().Done(): + close(eventCh) + return + } + } + }() + return eventCh +} + +func sendEvent(ctx *Context, data []byte) { + _, _ = ctx.ResponseWriter.Write(data) + ctx.ResponseWriter.Flush() +} + +func (ctx *Context) Bind(obj interface{}) error { + ct, exist := ctx.Request.Header["Content-Type"] + if !exist || len(ct) == 0 { + return ctx.BindJSON(obj) + } + i, l := 0, len(ct[0]) + for i < l && ct[0][i] != ';' { + i++ + } + switch ct[0][0:i] { + case ApplicationJSON: + return ctx.BindJSON(obj) + case ApplicationXML, TextXML: + return ctx.BindXML(obj) + case ApplicationForm: + return ctx.BindForm(obj) + case ApplicationProto: + return ctx.BindProtobuf(obj.(proto.Message)) + case ApplicationYAML: + return ctx.BindYAML(obj) + default: + return errors.New("Unsupported Content-Type:" + ct[0]) + } +} + +// Resp sends response based on the Accept Header +// By default response will be in JSON +func (ctx *Context) Resp(data interface{}) error { + accept := ctx.Input.Header("Accept") + switch accept { + case ApplicationYAML: + return ctx.YamlResp(data) + case ApplicationXML, TextXML: + return ctx.XMLResp(data) + case ApplicationProto: + return ctx.ProtoResp(data.(proto.Message)) + default: + return ctx.JSONResp(data) + } +} + +func (ctx *Context) JSONResp(data interface{}) error { + return ctx.Output.JSON(data, false, false) +} + +func (ctx *Context) XMLResp(data interface{}) error { + return ctx.Output.XML(data, false) +} + +func (ctx *Context) YamlResp(data interface{}) error { + return ctx.Output.YAML(data) +} + +func (ctx *Context) ProtoResp(data proto.Message) error { + return ctx.Output.Proto(data) +} + +// BindYAML only read data from http request body +func (ctx *Context) BindYAML(obj interface{}) error { + return yaml.Unmarshal(ctx.Input.RequestBody, obj) +} + +// BindForm will parse form values to struct via tag. +func (ctx *Context) BindForm(obj interface{}) error { + err := ctx.Request.ParseForm() + if err != nil { + return err + } + return ParseForm(ctx.Request.Form, obj) +} + +// BindJSON only read data from http request body +func (ctx *Context) BindJSON(obj interface{}) error { + return json.Unmarshal(ctx.Input.RequestBody, obj) +} + +// BindProtobuf only read data from http request body +func (ctx *Context) BindProtobuf(obj proto.Message) error { + return proto.Unmarshal(ctx.Input.RequestBody, obj) +} + +// BindXML only read data from http request body +func (ctx *Context) BindXML(obj interface{}) error { + return xml.Unmarshal(ctx.Input.RequestBody, obj) +} + +// ParseForm will parse form values to struct via tag. +func ParseForm(form url.Values, obj interface{}) error { + objT := reflect.TypeOf(obj) + objV := reflect.ValueOf(obj) + if !isStructPtr(objT) { + return fmt.Errorf("%v must be a struct pointer", obj) + } + objT = objT.Elem() + objV = objV.Elem() + return parseFormToStruct(form, objT, objV) +} + +func isStructPtr(t reflect.Type) bool { + return t.Kind() == reflect.Ptr && t.Elem().Kind() == reflect.Struct +} + +// Reset initializes Context, BeegoInput and BeegoOutput +func (ctx *Context) Reset(rw http.ResponseWriter, r *http.Request) { + ctx.Request = r + if ctx.ResponseWriter == nil { + ctx.ResponseWriter = &Response{} + } + ctx.ResponseWriter.reset(rw) + ctx.Input.Reset(ctx) + ctx.Output.Reset(ctx) + ctx._xsrfToken = "" +} + +// Redirect redirects to localurl with http header status code. +func (ctx *Context) Redirect(status int, localurl string) { + http.Redirect(ctx.ResponseWriter, ctx.Request, localurl, status) +} + +// Abort stops the request. +// If beego.ErrorMaps exists, panic body. +func (ctx *Context) Abort(status int, body string) { + ctx.Output.SetStatus(status) + panic(body) +} + +// WriteString writes a string to response body. +func (ctx *Context) WriteString(content string) { + _, _ = ctx.ResponseWriter.Write([]byte(content)) +} + +// GetCookie gets a cookie from a request for a given key. +// (Alias of BeegoInput.Cookie) +func (ctx *Context) GetCookie(key string) string { + return ctx.Input.Cookie(key) +} + +// SetCookie sets a cookie for a response. +// (Alias of BeegoOutput.Cookie) +func (ctx *Context) SetCookie(name string, value string, others ...interface{}) { + ctx.Output.Cookie(name, value, others...) +} + +// GetSecureCookie gets a secure cookie from a request for a given key. +func (ctx *Context) GetSecureCookie(Secret, key string) (string, bool) { + val := ctx.Input.Cookie(key) + if val == "" { + return "", false + } + + parts := strings.SplitN(val, "|", 3) + + if len(parts) != 3 { + return "", false + } + + vs := parts[0] + timestamp := parts[1] + sig := parts[2] + + h := hmac.New(sha256.New, []byte(Secret)) + _, err := fmt.Fprintf(h, "%s%s", vs, timestamp) + if err != nil { + return "", false + } + + if fmt.Sprintf("%02x", h.Sum(nil)) != sig { + return "", false + } + res, _ := base64.URLEncoding.DecodeString(vs) + return string(res), true +} + +// SetSecureCookie sets a secure cookie for a response. +func (ctx *Context) SetSecureCookie(Secret, name, value string, others ...interface{}) { + vs := base64.URLEncoding.EncodeToString([]byte(value)) + timestamp := strconv.FormatInt(time.Now().UnixNano(), 10) + h := hmac.New(sha256.New, []byte(Secret)) + _, _ = fmt.Fprintf(h, "%s%s", vs, timestamp) + sig := fmt.Sprintf("%02x", h.Sum(nil)) + cookie := strings.Join([]string{vs, timestamp, sig}, "|") + ctx.Output.Cookie(name, cookie, others...) +} + +// XSRFToken creates and returns an xsrf token string +func (ctx *Context) XSRFToken(key string, expire int64) string { + if ctx._xsrfToken == "" { + token, ok := ctx.GetSecureCookie(key, "_xsrf") + if !ok { + token = string(utils.RandomCreateBytes(32)) + // TODO make it configurable + ctx.SetSecureCookie(key, "_xsrf", token, expire, "/", "", true, true) + } + ctx._xsrfToken = token + } + return ctx._xsrfToken +} + +// CheckXSRFCookie checks if the XSRF token in this request is valid or not. +// The token can be provided in the request header in the form "X-Xsrftoken" or "X-CsrfToken" +// or in form field value named as "_xsrf". +func (ctx *Context) CheckXSRFCookie() bool { + token := ctx.Input.Query("_xsrf") + if token == "" { + token = ctx.Request.Header.Get("X-Xsrftoken") + } + if token == "" { + token = ctx.Request.Header.Get("X-Csrftoken") + } + if token == "" { + ctx.Abort(422, "422") + return false + } + if ctx._xsrfToken != token { + ctx.Abort(417, "417") + return false + } + return true +} + +// RenderMethodResult renders the return value of a controller method to the output +func (ctx *Context) RenderMethodResult(result interface{}) { + if result != nil { + renderer, ok := result.(Renderer) + if !ok { + err, ok := result.(error) + if ok { + renderer = errorRenderer(err) + } else { + renderer = jsonRenderer(result) + } + } + renderer.Render(ctx) + } +} + +// Session return session store of this context of request +func (ctx *Context) Session() (store session.Store, err error) { + if ctx.Input != nil { + if ctx.Input.CruSession != nil { + store = ctx.Input.CruSession + return + } else { + err = errors.New(`no valid session store(please initialize session)`) + return + } + } else { + err = errors.New(`no valid input`) + return + } +} + +// Response is a wrapper for the http.ResponseWriter +// Started: if true, response was already written to so the other handler will not be executed +type Response struct { + http.ResponseWriter + Started bool + Status int + Elapsed time.Duration +} + +func (r *Response) reset(rw http.ResponseWriter) { + r.ResponseWriter = rw + r.Status = 0 + r.Started = false +} + +// Write writes the data to the connection as part of a HTTP reply, +// and sets `Started` to true. +// Started: if true, the response was already sent +func (r *Response) Write(p []byte) (int, error) { + r.Started = true + return r.ResponseWriter.Write(p) +} + +// WriteHeader sends a HTTP response header with status code, +// and sets `Started` to true. +func (r *Response) WriteHeader(code int) { + if r.Status > 0 { + // prevent multiple response.WriteHeader calls + return + } + r.Status = code + r.Started = true + r.ResponseWriter.WriteHeader(code) +} + +// Hijack hijacker for http +func (r *Response) Hijack() (net.Conn, *bufio.ReadWriter, error) { + hj, ok := r.ResponseWriter.(http.Hijacker) + if !ok { + return nil, nil, errors.New("webserver doesn't support hijacking") + } + return hj.Hijack() +} + +// Flush http.Flusher +func (r *Response) Flush() { + if f, ok := r.ResponseWriter.(http.Flusher); ok { + f.Flush() + } +} + +// CloseNotify http.CloseNotifier +func (r *Response) CloseNotify() <-chan bool { + if cn, ok := r.ResponseWriter.(http.CloseNotifier); ok { + return cn.CloseNotify() + } + return nil +} + +// Pusher http.Pusher +func (r *Response) Pusher() (pusher http.Pusher) { + if pusher, ok := r.ResponseWriter.(http.Pusher); ok { + return pusher + } + return nil +} diff --git a/server/web/context/context_test.go b/server/web/context/context_test.go new file mode 100644 index 0000000000..69b3b6e8a1 --- /dev/null +++ b/server/web/context/context_test.go @@ -0,0 +1,232 @@ +// Copyright 2016 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package context + +import ( + "github.com/beego/beego/v2/server/web/session" + "github.com/stretchr/testify/assert" + "net/http" + "net/http/httptest" + "strings" + "sync" + "testing" + "time" +) + +// Test concurrency safety +func TestEventStreamResp_Concurrency(t *testing.T) { + req, err := http.NewRequest("GET", "/events", nil) + if err != nil { + t.Fatalf("Failed to create request: %v", err) + } + + resp := httptest.NewRecorder() + beegoResp := &Response{ResponseWriter: resp} + + ctxInstance := &Context{ + Request: req, + ResponseWriter: beegoResp, + } + + eventCh := ctxInstance.EventStreamResp() + + var wg sync.WaitGroup + messages := []string{"msg1", "msg2", "msg3", "msg4", "msg5"} + + for _, msg := range messages { + wg.Add(1) + go func(m string) { + defer wg.Done() + eventCh <- []byte(m) + }(msg) + } + + wg.Wait() + time.Sleep(100 * time.Millisecond) + + received := resp.Body.String() + for _, msg := range messages { + if !strings.Contains(received, msg) { + t.Errorf("Response missing message: %s", msg) + } + } +} + +func TestEventStreamResp_ChannelScenarios(t *testing.T) { + tests := []struct { + name string + setup func(*Context) chan<- []byte + verify func(*testing.T, *httptest.ResponseRecorder, chan<- []byte) bool + expectPanic bool + }{ + { + name: "Scenario 1: Send multiple data, channel auto-closed", + setup: func(ctx *Context) chan<- []byte { + return ctx.EventStreamResp() + }, + verify: func(t *testing.T, resp *httptest.ResponseRecorder, eventCh chan<- []byte) bool { + messages := []string{"msg1", "msg2", "msg3", "msg4"} + + for _, msg := range messages { + eventCh <- []byte(msg) + } + + time.Sleep(100 * time.Millisecond) + + received := resp.Body.String() + for _, msg := range messages { + if !strings.Contains(received, msg) { + t.Errorf("Expected message %q not found in response", msg) + } + } + var isPanic bool + return isPanic + }, + expectPanic: false, + }, + { + name: "Scenario 2: Send data then manually close channel", + setup: func(ctx *Context) chan<- []byte { + return ctx.EventStreamResp() + }, + verify: func(t *testing.T, resp *httptest.ResponseRecorder, eventCh chan<- []byte) (isPanic bool) { + eventCh <- []byte("first message") + time.Sleep(10 * time.Millisecond) + if !strings.Contains(resp.Body.String(), "first message") { + t.Error("First message not received") + } + close(eventCh) + func() { + defer func() { + if r := recover(); r != nil { + isPanic = true + } + }() + eventCh <- []byte("should panic") + }() + return + }, + expectPanic: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req, _ := http.NewRequest("GET", "/events", nil) + resp := httptest.NewRecorder() + + ctxInstance := &Context{ + Request: req, + ResponseWriter: &Response{ResponseWriter: resp}, + } + + eventCh := tt.setup(ctxInstance) + + resPanicStatus := tt.verify(t, resp, eventCh) + assert.Equal(t, tt.expectPanic, resPanicStatus) + + }) + } +} + +func TestXsrfReset_01(t *testing.T) { + r := &http.Request{} + c := NewContext() + c.Request = r + c.ResponseWriter = &Response{} + c.ResponseWriter.reset(httptest.NewRecorder()) + c.Output.Reset(c) + c.Input.Reset(c) + c.XSRFToken("key", 16) + if c._xsrfToken == "" { + t.FailNow() + } + token := c._xsrfToken + c.Reset(&Response{ResponseWriter: httptest.NewRecorder()}, r) + if c._xsrfToken != "" { + t.FailNow() + } + c.XSRFToken("key", 16) + if c._xsrfToken == "" { + t.FailNow() + } + if token == c._xsrfToken { + t.FailNow() + } +} + +func TestContext_Session(t *testing.T) { + c := NewContext() + if store, err := c.Session(); store != nil || err == nil { + t.FailNow() + } +} + +func TestContext_Session1(t *testing.T) { + c := Context{} + if store, err := c.Session(); store != nil || err == nil { + t.FailNow() + } +} + +func TestContext_Session2(t *testing.T) { + c := NewContext() + c.Input.CruSession = &session.MemSessionStore{} + + if store, err := c.Session(); store == nil || err != nil { + t.FailNow() + } +} + +func TestSetCookie(t *testing.T) { + type cookie struct { + Name string + Value string + MaxAge int64 + Path string + Domain string + Secure bool + HttpOnly bool + SameSite string + } + type testItem struct { + item cookie + want string + } + cases := []struct { + request string + valueGp []testItem + }{ + {"/", []testItem{{cookie{"name", "value", -1, "/", "", false, false, "Strict"}, "name=value; Max-Age=0; Path=/; SameSite=Strict"}}}, + {"/", []testItem{{cookie{"name", "value", -1, "/", "", false, false, "Lax"}, "name=value; Max-Age=0; Path=/; SameSite=Lax"}}}, + {"/", []testItem{{cookie{"name", "value", -1, "/", "", false, false, "None"}, "name=value; Max-Age=0; Path=/; SameSite=None"}}}, + {"/", []testItem{{cookie{"name", "value", -1, "/", "", false, false, ""}, "name=value; Max-Age=0; Path=/"}}}, + } + for _, c := range cases { + r, _ := http.NewRequest("GET", c.request, nil) + output := NewOutput() + output.Context = NewContext() + output.Context.Reset(httptest.NewRecorder(), r) + for _, item := range c.valueGp { + params := item.item + others := []interface{}{params.MaxAge, params.Path, params.Domain, params.Secure, params.HttpOnly, params.SameSite} + output.Context.SetCookie(params.Name, params.Value, others...) + got := output.Context.ResponseWriter.Header().Get("Set-Cookie") + if got != item.want { + t.Fatalf("SetCookie error,should be:\n%v \ngot:\n%v", item.want, got) + } + } + } +} diff --git a/server/web/context/form.go b/server/web/context/form.go new file mode 100644 index 0000000000..caa58af321 --- /dev/null +++ b/server/web/context/form.go @@ -0,0 +1,178 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package context + +import ( + "net/url" + "reflect" + "strconv" + "strings" + "time" +) + +var ( + sliceOfInts = reflect.TypeOf([]int(nil)) + sliceOfStrings = reflect.TypeOf([]string(nil)) +) + +// ParseForm will parse form values to struct via tag. +// Support for anonymous struct. +func parseFormToStruct(form url.Values, objT reflect.Type, objV reflect.Value) error { + for i := 0; i < objT.NumField(); i++ { + fieldV := objV.Field(i) + if !fieldV.CanSet() { + continue + } + + fieldT := objT.Field(i) + if fieldT.Anonymous && fieldT.Type.Kind() == reflect.Struct { + err := parseFormToStruct(form, fieldT.Type, fieldV) + if err != nil { + return err + } + continue + } + + tag, ok := formTagName(fieldT) + if !ok { + continue + } + + value, ok := formValue(tag, form, fieldT) + if !ok { + continue + } + + switch fieldT.Type.Kind() { + case reflect.Bool: + b, err := parseFormBoolValue(value) + if err != nil { + return err + } + fieldV.SetBool(b) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + x, err := strconv.ParseInt(value, 10, 64) + if err != nil { + return err + } + fieldV.SetInt(x) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + x, err := strconv.ParseUint(value, 10, 64) + if err != nil { + return err + } + fieldV.SetUint(x) + case reflect.Float32, reflect.Float64: + x, err := strconv.ParseFloat(value, 64) + if err != nil { + return err + } + fieldV.SetFloat(x) + case reflect.Interface: + fieldV.Set(reflect.ValueOf(value)) + case reflect.String: + fieldV.SetString(value) + case reflect.Struct: + if fieldT.Type.String() == "time.Time" { + t, err := parseFormTime(value) + if err != nil { + return err + } + fieldV.Set(reflect.ValueOf(t)) + } + case reflect.Slice: + if fieldT.Type == sliceOfInts { + formVals := form[tag] + fieldV.Set(reflect.MakeSlice(reflect.SliceOf(reflect.TypeOf(int(1))), len(formVals), len(formVals))) + for i := 0; i < len(formVals); i++ { + val, err := strconv.Atoi(formVals[i]) + if err != nil { + return err + } + fieldV.Index(i).SetInt(int64(val)) + } + } else if fieldT.Type == sliceOfStrings { + formVals := form[tag] + fieldV.Set(reflect.MakeSlice(reflect.SliceOf(reflect.TypeOf("")), len(formVals), len(formVals))) + for i := 0; i < len(formVals); i++ { + fieldV.Index(i).SetString(formVals[i]) + } + } + } + } + return nil +} + +// nolint +func parseFormTime(value string) (time.Time, error) { + var pattern string + if len(value) >= 25 { + value = value[:25] + pattern = time.RFC3339 + } else if strings.HasSuffix(strings.ToUpper(value), "Z") { + pattern = time.RFC3339 + } else if len(value) >= 19 { + if strings.Contains(value, "T") { + pattern = formatDateTimeT + } else { + pattern = formatDateTime + } + value = value[:19] + } else if len(value) >= 10 { + if len(value) > 10 { + value = value[:10] + } + pattern = formatDate + } else if len(value) >= 8 { + if len(value) > 8 { + value = value[:8] + } + pattern = formatTime + } + return time.ParseInLocation(pattern, value, time.Local) +} + +func parseFormBoolValue(value string) (bool, error) { + if strings.ToLower(value) == "on" || strings.ToLower(value) == "1" || strings.ToLower(value) == "yes" { + return true, nil + } + if strings.ToLower(value) == "off" || strings.ToLower(value) == "0" || strings.ToLower(value) == "no" { + return false, nil + } + return strconv.ParseBool(value) +} + +// nolint +func formTagName(fieldT reflect.StructField) (string, bool) { + tags := strings.Split(fieldT.Tag.Get("form"), ",") + var tag string + if len(tags) == 0 || tags[0] == "" { + tag = fieldT.Name + } else if tags[0] == "-" { + return "", false + } else { + tag = tags[0] + } + return tag, true +} + +func formValue(tag string, form url.Values, fieldT reflect.StructField) (string, bool) { + formValues := form[tag] + if len(formValues) == 0 { + defaultValue := fieldT.Tag.Get("default") + return defaultValue, defaultValue != "" + } + return formValues[0], true +} diff --git a/server/web/context/form_test.go b/server/web/context/form_test.go new file mode 100644 index 0000000000..4ebdac1a84 --- /dev/null +++ b/server/web/context/form_test.go @@ -0,0 +1,88 @@ +// Copyright 2024 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package context + +import ( + "net/url" + "reflect" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestFormValue(t *testing.T) { + typ := reflect.TypeOf(TestStruct{}) + defField, _ := typ.FieldByName("DefaultField") + noDefField, _ := typ.FieldByName("NoDefaultField") + testCases := []struct { + name string + tag string + + form url.Values + field reflect.StructField + + wantRes string + wantOk bool + }{ + { + name: "use value", + tag: "defaultField", + field: defField, + form: map[string][]string{ + "defaultField": {"abc", "bcd"}, + }, + wantRes: "abc", + wantOk: true, + }, + { + name: "empty value", + tag: "defaultField", + field: defField, + form: map[string][]string{ + "defaultField": {"", "bcd"}, + }, + wantRes: "", + wantOk: true, + }, + { + name: "use default value", + tag: "defaultField", + field: defField, + form: map[string][]string{}, + wantRes: "Tom", + wantOk: true, + }, + { + name: "no value", + tag: "no", + field: noDefField, + form: map[string][]string{}, + wantOk: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + val, ok := formValue(tc.tag, tc.form, tc.field) + assert.Equal(t, tc.wantRes, val) + assert.Equal(t, tc.wantOk, ok) + }) + } +} + +type TestStruct struct { + DefaultField string `form:"defaultField" default:"Tom"` + NoDefaultField string `form:"noDefaultField"` +} diff --git a/context/input.go b/server/web/context/input.go similarity index 88% rename from context/input.go rename to server/web/context/input.go index 8195215889..3280d68fa5 100644 --- a/context/input.go +++ b/server/web/context/input.go @@ -19,7 +19,6 @@ import ( "compress/gzip" "errors" "io" - "io/ioutil" "net" "net/http" "net/url" @@ -27,8 +26,9 @@ import ( "regexp" "strconv" "strings" + "sync" - "github.com/astaxie/beego/session" + "github.com/beego/beego/v2/server/web/session" ) // Regexes for checking the accept headers @@ -42,19 +42,20 @@ var ( ) // BeegoInput operates the http request header, data, cookie and body. -// it also contains router params and current session. +// Contains router params and current session. type BeegoInput struct { Context *Context CruSession session.Store pnames []string pvalues []string data map[interface{}]interface{} // store some values in this context when calling context in filter or controller. + dataLock sync.RWMutex RequestBody []byte RunMethod string RunController reflect.Type } -// NewInput return BeegoInput generated by Context. +// NewInput returns the BeegoInput generated by context. func NewInput() *BeegoInput { return &BeegoInput{ pnames: make([]string, 0, maxParam), @@ -63,37 +64,39 @@ func NewInput() *BeegoInput { } } -// Reset init the BeegoInput +// Reset initializes the BeegoInput func (input *BeegoInput) Reset(ctx *Context) { input.Context = ctx input.CruSession = nil input.pnames = input.pnames[:0] input.pvalues = input.pvalues[:0] + input.dataLock.Lock() input.data = nil + input.dataLock.Unlock() input.RequestBody = []byte{} } -// Protocol returns request protocol name, such as HTTP/1.1 . +// Protocol returns the request protocol name, such as HTTP/1.1 . func (input *BeegoInput) Protocol() string { return input.Context.Request.Proto } -// URI returns full request url with query string, fragment. +// URI returns the full request url with query, string and fragment. func (input *BeegoInput) URI() string { return input.Context.Request.RequestURI } -// URL returns request url path (without query string, fragment). +// URL returns the request url path (without query, string and fragment). func (input *BeegoInput) URL() string { return input.Context.Request.URL.Path } -// Site returns base site url as scheme://domain type. +// Site returns the base site url as scheme://domain type. func (input *BeegoInput) Site() string { return input.Scheme() + "://" + input.Domain() } -// Scheme returns request scheme as "http" or "https". +// Scheme returns the request scheme as "http" or "https". func (input *BeegoInput) Scheme() string { if scheme := input.Header("X-Forwarded-Proto"); scheme != "" { return scheme @@ -107,14 +110,13 @@ func (input *BeegoInput) Scheme() string { return "https" } -// Domain returns host name. -// Alias of Host method. +// Domain returns the host name (alias of host method) func (input *BeegoInput) Domain() string { return input.Host() } -// Host returns host name. -// if no host info in request, return localhost. +// Host returns the host name. +// If no host info in request, return localhost. func (input *BeegoInput) Host() string { if input.Context.Request.Host != "" { if hostPart, _, err := net.SplitHostPort(input.Context.Request.Host); err == nil { @@ -130,7 +132,7 @@ func (input *BeegoInput) Method() string { return input.Context.Request.Method } -// Is returns boolean of this request is on given method, such as Is("POST"). +// Is returns the boolean value of this request is on given method, such as Is("POST"). func (input *BeegoInput) Is(method string) bool { return input.Method() == method } @@ -150,7 +152,7 @@ func (input *BeegoInput) IsHead() bool { return input.Is("HEAD") } -// IsOptions Is this a OPTIONS method request? +// IsOptions Is this an OPTIONS method request? func (input *BeegoInput) IsOptions() bool { return input.Is("OPTIONS") } @@ -170,7 +172,7 @@ func (input *BeegoInput) IsPatch() bool { return input.Is("PATCH") } -// IsAjax returns boolean of this request is generated by ajax. +// IsAjax returns boolean of is this request generated by ajax. func (input *BeegoInput) IsAjax() bool { return input.Header("X-Requested-With") == "XMLHttpRequest" } @@ -204,6 +206,7 @@ func (input *BeegoInput) AcceptsXML() bool { func (input *BeegoInput) AcceptsJSON() bool { return acceptsJSONRegex.MatchString(input.Header("Accept")) } + // AcceptsYAML Checks if request accepts json response func (input *BeegoInput) AcceptsYAML() bool { return acceptsYAMLRegex.MatchString(input.Header("Accept")) @@ -246,7 +249,7 @@ func (input *BeegoInput) Refer() string { } // SubDomains returns sub domain string. -// if aa.bb.domain.com, returns aa.bb . +// if aa.bb.domain.com, returns aa.bb func (input *BeegoInput) SubDomains() string { parts := strings.Split(input.Host(), ".") if len(parts) >= 3 { @@ -279,6 +282,11 @@ func (input *BeegoInput) ParamsLen() int { func (input *BeegoInput) Param(key string) string { for i, v := range input.pnames { if v == key && i <= len(input.pvalues) { + // we cannot use url.PathEscape(input.pvalues[i]) + // for example, if the value is /a/b + // after url.PathEscape(input.pvalues[i]), the value is %2Fa%2Fb + // However, the value is used in ControllerRegister.ServeHTTP + // and split by "/", so function crash... return input.pvalues[i] } } @@ -296,7 +304,7 @@ func (input *BeegoInput) Params() map[string]string { return m } -// SetParam will set the param with key and value +// SetParam sets the param with key and value func (input *BeegoInput) SetParam(key, val string) { // check if already exists for i, v := range input.pnames { @@ -309,9 +317,8 @@ func (input *BeegoInput) SetParam(key, val string) { input.pnames = append(input.pnames, key) } -// ResetParams clears any of the input's Params -// This function is used to clear parameters so they may be reset between filter -// passes. +// ResetParams clears any of the input's params +// Used to clear parameters so they may be reset between filter passes. func (input *BeegoInput) ResetParams() { input.pnames = input.pnames[:0] input.pvalues = input.pvalues[:0] @@ -323,8 +330,14 @@ func (input *BeegoInput) Query(key string) string { return val } if input.Context.Request.Form == nil { - input.Context.Request.ParseForm() + input.dataLock.Lock() + if input.Context.Request.Form == nil { + input.Context.Request.ParseForm() + } + input.dataLock.Unlock() } + input.dataLock.RLock() + defer input.dataLock.RUnlock() return input.Context.Request.Form.Get(key) } @@ -347,7 +360,7 @@ func (input *BeegoInput) Cookie(key string) string { // Session returns current session item value by a given key. // if non-existed, return nil. func (input *BeegoInput) Session(key interface{}) interface{} { - return input.CruSession.Get(key) + return input.CruSession.Get(nil, key) } // CopyBody returns the raw request body data as bytes. @@ -363,20 +376,22 @@ func (input *BeegoInput) CopyBody(MaxMemory int64) []byte { if err != nil { return nil } - requestbody, _ = ioutil.ReadAll(reader) + requestbody, _ = io.ReadAll(reader) } else { - requestbody, _ = ioutil.ReadAll(safe) + requestbody, _ = io.ReadAll(safe) } input.Context.Request.Body.Close() bf := bytes.NewBuffer(requestbody) - input.Context.Request.Body = http.MaxBytesReader(input.Context.ResponseWriter, ioutil.NopCloser(bf), MaxMemory) + input.Context.Request.Body = http.MaxBytesReader(input.Context.ResponseWriter, io.NopCloser(bf), MaxMemory) input.RequestBody = requestbody return requestbody } -// Data return the implicit data in the input +// Data returns the implicit data in the input func (input *BeegoInput) Data() map[interface{}]interface{} { + input.dataLock.Lock() + defer input.dataLock.Unlock() if input.data == nil { input.data = make(map[interface{}]interface{}) } @@ -385,6 +400,8 @@ func (input *BeegoInput) Data() map[interface{}]interface{} { // GetData returns the stored data in this context. func (input *BeegoInput) GetData(key interface{}) interface{} { + input.dataLock.Lock() + defer input.dataLock.Unlock() if v, ok := input.data[key]; ok { return v } @@ -392,18 +409,20 @@ func (input *BeegoInput) GetData(key interface{}) interface{} { } // SetData stores data with given key in this context. -// This data are only available in this context. +// This data is only available in this context. func (input *BeegoInput) SetData(key, val interface{}) { + input.dataLock.Lock() + defer input.dataLock.Unlock() if input.data == nil { input.data = make(map[interface{}]interface{}) } input.data[key] = val } -// ParseFormOrMulitForm parseForm or parseMultiForm based on Content-type -func (input *BeegoInput) ParseFormOrMulitForm(maxMemory int64) error { +// ParseFormOrMultiForm parseForm or parseMultiForm based on Content-type +func (input *BeegoInput) ParseFormOrMultiForm(maxMemory int64) error { // Parse the body depending on the content type. - if strings.Contains(input.Header("Content-Type"), "multipart/form-data") { + if input.IsUpload() { if err := input.Context.Request.ParseMultipartForm(maxMemory); err != nil { return errors.New("Error parsing request body:" + err.Error()) } diff --git a/context/input_test.go b/server/web/context/input_test.go similarity index 94% rename from context/input_test.go rename to server/web/context/input_test.go index db812a0f03..058dc80545 100644 --- a/context/input_test.go +++ b/server/web/context/input_test.go @@ -71,11 +71,13 @@ func TestBind(t *testing.T) { {"/?human.Nick=astaxie", []testItem{{"human", Human{}, Human{Nick: "astaxie"}}}}, {"/?human.ID=888&human.Nick=astaxie&human.Ms=true&human[Pwd]=pass", []testItem{{"human", Human{}, Human{ID: 888, Nick: "astaxie", Ms: true, Pwd: "pass"}}}}, - {"/?human[0].ID=888&human[0].Nick=astaxie&human[0].Ms=true&human[0][Pwd]=pass01&human[1].ID=999&human[1].Nick=ysqi&human[1].Ms=On&human[1].Pwd=pass02", + { + "/?human[0].ID=888&human[0].Nick=astaxie&human[0].Ms=true&human[0][Pwd]=pass01&human[1].ID=999&human[1].Nick=ysqi&human[1].Ms=On&human[1].Pwd=pass02", []testItem{{"human", []Human{}, []Human{ {ID: 888, Nick: "astaxie", Ms: true, Pwd: "pass01"}, {ID: 999, Nick: "ysqi", Ms: true, Pwd: "pass02"}, - }}}}, + }}}, + }, { "/?id=123&isok=true&ft=1.2&ol[0]=1&ol[1]=2&ul[]=str&ul[]=array&human.Nick=astaxie", @@ -203,5 +205,15 @@ func TestParams(t *testing.T) { if val := inp.Param("p2"); val != "val2_ver2" { t.Fatalf("Input.Param wrong value: %s, expected %s", val, "val1_ver2") } +} +func BenchmarkQuery(b *testing.B) { + beegoInput := NewInput() + beegoInput.Context = NewContext() + beegoInput.Context.Request, _ = http.NewRequest("POST", "http://www.example.com/?q=foo", nil) + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + beegoInput.Query("q") + } + }) } diff --git a/context/output.go b/server/web/context/output.go similarity index 81% rename from context/output.go rename to server/web/context/output.go index 238dcf45ef..b09c792b9b 100644 --- a/context/output.go +++ b/server/web/context/output.go @@ -31,7 +31,8 @@ import ( "strings" "time" - yaml "gopkg.in/yaml.v2" + "google.golang.org/protobuf/proto" + "gopkg.in/yaml.v3" ) // BeegoOutput does work for sending response header. @@ -42,12 +43,12 @@ type BeegoOutput struct { } // NewOutput returns new BeegoOutput. -// it contains nothing now. +// Empty when initialized func NewOutput() *BeegoOutput { return &BeegoOutput{} } -// Reset init BeegoOutput +// Reset initializes BeegoOutput func (output *BeegoOutput) Reset(ctx *Context) { output.Context = ctx output.Status = 0 @@ -58,12 +59,12 @@ func (output *BeegoOutput) Header(key, val string) { output.Context.ResponseWriter.Header().Set(key, val) } -// Body sets response body content. -// if EnableGzip, compress content string. -// it sends out response body directly. +// Body sets the response body content. +// if EnableGzip, content is compressed. +// Sends out response body directly. func (output *BeegoOutput) Body(content []byte) error { var encoding string - var buf = &bytes.Buffer{} + buf := &bytes.Buffer{} if output.EnableGzip { encoding = ParseEncoding(output.Context.Request) } @@ -85,13 +86,13 @@ func (output *BeegoOutput) Body(content []byte) error { return nil } -// Cookie sets cookie value via given key. -// others are ordered as cookie's max age time, path,domain, secure and httponly. +// Cookie sets a cookie value via given key. +// others: used to set a cookie's max age time, path,domain, secure and httponly. func (output *BeegoOutput) Cookie(name string, value string, others ...interface{}) { var b bytes.Buffer fmt.Fprintf(&b, "%s=%s", sanitizeName(name), sanitizeValue(value)) - //fix cookie not work in IE + // fix cookie not work in IE if len(others) > 0 { var maxAge int64 @@ -117,13 +118,13 @@ func (output *BeegoOutput) Cookie(name string, value string, others ...interface // can use nil skip set // default "/" + tmpPath := "/" if len(others) > 1 { if v, ok := others[1].(string); ok && len(v) > 0 { - fmt.Fprintf(&b, "; Path=%s", sanitizeValue(v)) + tmpPath = sanitizeValue(v) } - } else { - fmt.Fprintf(&b, "; Path=%s", "/") } + fmt.Fprintf(&b, "; Path=%s", tmpPath) // default empty if len(others) > 2 { @@ -155,6 +156,13 @@ func (output *BeegoOutput) Cookie(name string, value string, others ...interface } } + // default empty + if len(others) > 5 { + if v, ok := others[5].(string); ok && len(v) > 0 { + fmt.Fprintf(&b, "; SameSite=%s", sanitizeValue(v)) + } + } + output.Context.ResponseWriter.Header().Add("Set-Cookie", b.String()) } @@ -183,7 +191,7 @@ func errorRenderer(err error) Renderer { }) } -// JSON writes json to response body. +// JSON writes json to the response body. // if encoding is true, it converts utf-8 to \u0000 type. func (output *BeegoOutput) JSON(data interface{}, hasIndent bool, encoding bool) error { output.Header("Content-Type", "application/json; charset=utf-8") @@ -204,7 +212,7 @@ func (output *BeegoOutput) JSON(data interface{}, hasIndent bool, encoding bool) return output.Body(content) } -// YAML writes yaml to response body. +// YAML writes yaml to the response body. func (output *BeegoOutput) YAML(data interface{}) error { output.Header("Content-Type", "application/x-yaml; charset=utf-8") var content []byte @@ -217,7 +225,20 @@ func (output *BeegoOutput) YAML(data interface{}) error { return output.Body(content) } -// JSONP writes jsonp to response body. +// Proto writes protobuf to the response body. +func (output *BeegoOutput) Proto(data proto.Message) error { + output.Header("Content-Type", "application/x-protobuf; charset=utf-8") + var content []byte + var err error + content, err = proto.Marshal(data) + if err != nil { + http.Error(output.Context.ResponseWriter, err.Error(), http.StatusInternalServerError) + return err + } + return output.Body(content) +} + +// JSONP writes jsonp to the response body. func (output *BeegoOutput) JSONP(data interface{}, hasIndent bool) error { output.Header("Content-Type", "application/javascript; charset=utf-8") var content []byte @@ -243,7 +264,7 @@ func (output *BeegoOutput) JSONP(data interface{}, hasIndent bool) error { return output.Body(callbackContent.Bytes()) } -// XML writes xml string to response body. +// XML writes xml string to the response body. func (output *BeegoOutput) XML(data interface{}, hasIndent bool) error { output.Header("Content-Type", "application/xml; charset=utf-8") var content []byte @@ -260,21 +281,21 @@ func (output *BeegoOutput) XML(data interface{}, hasIndent bool) error { return output.Body(content) } -// ServeFormatted serve YAML, XML OR JSON, depending on the value of the Accept header -func (output *BeegoOutput) ServeFormatted(data interface{}, hasIndent bool, hasEncode ...bool) { +// ServeFormatted serves YAML, XML or JSON, depending on the value of the Accept header +func (output *BeegoOutput) ServeFormatted(data interface{}, hasIndent bool, hasEncode ...bool) error { accept := output.Context.Input.Header("Accept") switch accept { case ApplicationYAML: - output.YAML(data) + return output.YAML(data) case ApplicationXML, TextXML: - output.XML(data, hasIndent) + return output.XML(data, hasIndent) default: - output.JSON(data, hasIndent, len(hasEncode) > 0 && hasEncode[0]) + return output.JSON(data, hasIndent, len(hasEncode) > 0 && hasEncode[0]) } } // Download forces response for download file. -// it prepares the download response header automatically. +// Prepares the download response header automatically. func (output *BeegoOutput) Download(file string, filename ...string) { // check get file error, file not found or other error. if _, err := os.Stat(file); err != nil { @@ -288,7 +309,7 @@ func (output *BeegoOutput) Download(file string, filename ...string) { } else { fName = filepath.Base(file) } - //https://tools.ietf.org/html/rfc6266#section-4.3 + // https://tools.ietf.org/html/rfc6266#section-4.3 fn := url.PathEscape(fName) if fName == fn { fn = "filename=" + fn @@ -323,61 +344,61 @@ func (output *BeegoOutput) ContentType(ext string) { } } -// SetStatus sets response status code. -// It writes response header directly. +// SetStatus sets the response status code. +// Writes response header directly. func (output *BeegoOutput) SetStatus(status int) { output.Status = status } -// IsCachable returns boolean of this request is cached. +// IsCachable returns boolean of if this request is cached. // HTTP 304 means cached. func (output *BeegoOutput) IsCachable() bool { return output.Status >= 200 && output.Status < 300 || output.Status == 304 } -// IsEmpty returns boolean of this request is empty. +// IsEmpty returns boolean of if this request is empty. // HTTP 201,204 and 304 means empty. func (output *BeegoOutput) IsEmpty() bool { return output.Status == 201 || output.Status == 204 || output.Status == 304 } -// IsOk returns boolean of this request runs well. +// IsOk returns boolean of if this request was ok. // HTTP 200 means ok. func (output *BeegoOutput) IsOk() bool { return output.Status == 200 } -// IsSuccessful returns boolean of this request runs successfully. +// IsSuccessful returns boolean of this request was successful. // HTTP 2xx means ok. func (output *BeegoOutput) IsSuccessful() bool { return output.Status >= 200 && output.Status < 300 } -// IsRedirect returns boolean of this request is redirection header. +// IsRedirect returns boolean of if this request is redirected. // HTTP 301,302,307 means redirection. func (output *BeegoOutput) IsRedirect() bool { return output.Status == 301 || output.Status == 302 || output.Status == 303 || output.Status == 307 } -// IsForbidden returns boolean of this request is forbidden. +// IsForbidden returns boolean of if this request is forbidden. // HTTP 403 means forbidden. func (output *BeegoOutput) IsForbidden() bool { return output.Status == 403 } -// IsNotFound returns boolean of this request is not found. +// IsNotFound returns boolean of if this request is not found. // HTTP 404 means not found. func (output *BeegoOutput) IsNotFound() bool { return output.Status == 404 } -// IsClientError returns boolean of this request client sends error data. +// IsClientError returns boolean of if this request client sends error data. // HTTP 4xx means client error. func (output *BeegoOutput) IsClientError() bool { return output.Status >= 400 && output.Status < 500 } -// IsServerError returns boolean of this server handler errors. +// IsServerError returns boolean of if this server handler errors. // HTTP 5xx means server internal error. func (output *BeegoOutput) IsServerError() bool { return output.Status >= 500 && output.Status < 600 @@ -404,5 +425,5 @@ func stringsToJSON(str string) string { // Session sets session item value with given key. func (output *BeegoOutput) Session(name interface{}, value interface{}) { - output.Context.Input.CruSession.Set(name, value) + output.Context.Input.CruSession.Set(nil, name, value) } diff --git a/context/param/conv.go b/server/web/context/param/conv.go similarity index 95% rename from context/param/conv.go rename to server/web/context/param/conv.go index c200e0088d..eecb6acbd9 100644 --- a/context/param/conv.go +++ b/server/web/context/param/conv.go @@ -4,8 +4,8 @@ import ( "fmt" "reflect" - beecontext "github.com/astaxie/beego/context" - "github.com/astaxie/beego/logs" + "github.com/beego/beego/v2/core/logs" + beecontext "github.com/beego/beego/v2/server/web/context" ) // ConvertParams converts http method params to values that will be passed to the method controller as arguments diff --git a/context/param/methodparams.go b/server/web/context/param/methodparams.go similarity index 85% rename from context/param/methodparams.go rename to server/web/context/param/methodparams.go index cd6708a27f..22ff0e4369 100644 --- a/context/param/methodparams.go +++ b/server/web/context/param/methodparams.go @@ -5,7 +5,7 @@ import ( "strings" ) -//MethodParam keeps param information to be auto passed to controller methods +// MethodParam keeps param information to be auto passed to controller methods type MethodParam struct { name string in paramType @@ -22,7 +22,7 @@ const ( header ) -//New creates a new MethodParam with name and specific options +// New creates a new MethodParam with name and specific options func New(name string, opts ...MethodParamOption) *MethodParam { return newParam(name, nil, opts) } @@ -35,7 +35,7 @@ func newParam(name string, parser paramParser, opts []MethodParamOption) (param return } -//Make creates an array of MethodParmas or an empty array +// Make creates an array of MethodParmas or an empty array func Make(list ...*MethodParam) []*MethodParam { if len(list) > 0 { return list diff --git a/context/param/options.go b/server/web/context/param/options.go similarity index 100% rename from context/param/options.go rename to server/web/context/param/options.go diff --git a/context/param/parsers.go b/server/web/context/param/parsers.go similarity index 93% rename from context/param/parsers.go rename to server/web/context/param/parsers.go index 421aecf08d..99ad9275e3 100644 --- a/context/param/parsers.go +++ b/server/web/context/param/parsers.go @@ -18,7 +18,7 @@ func getParser(param *MethodParam, t reflect.Type) paramParser { reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: return intParser{} case reflect.Slice: - if t.Elem().Kind() == reflect.Uint8 { //treat []byte as string + if t.Elem().Kind() == reflect.Uint8 { // treat []byte as string return stringParser{} } if param.in == body { @@ -55,29 +55,25 @@ func (f parserFunc) parse(value string, toType reflect.Type) (interface{}, error return f(value, toType) } -type boolParser struct { -} +type boolParser struct{} func (p boolParser) parse(value string, toType reflect.Type) (interface{}, error) { return strconv.ParseBool(value) } -type stringParser struct { -} +type stringParser struct{} func (p stringParser) parse(value string, toType reflect.Type) (interface{}, error) { return value, nil } -type intParser struct { -} +type intParser struct{} func (p intParser) parse(value string, toType reflect.Type) (interface{}, error) { return strconv.Atoi(value) } -type floatParser struct { -} +type floatParser struct{} func (p floatParser) parse(value string, toType reflect.Type) (interface{}, error) { if toType.Kind() == reflect.Float32 { @@ -90,8 +86,7 @@ func (p floatParser) parse(value string, toType reflect.Type) (interface{}, erro return strconv.ParseFloat(value, 64) } -type timeParser struct { -} +type timeParser struct{} func (p timeParser) parse(value string, toType reflect.Type) (result interface{}, err error) { result, err = time.Parse(time.RFC3339, value) @@ -101,8 +96,7 @@ func (p timeParser) parse(value string, toType reflect.Type) (result interface{} return } -type jsonParser struct { -} +type jsonParser struct{} func (p jsonParser) parse(value string, toType reflect.Type) (interface{}, error) { pResult := reflect.New(toType) diff --git a/context/param/parsers_test.go b/server/web/context/param/parsers_test.go similarity index 90% rename from context/param/parsers_test.go rename to server/web/context/param/parsers_test.go index 7065a28ed5..931d88f604 100644 --- a/context/param/parsers_test.go +++ b/server/web/context/param/parsers_test.go @@ -1,8 +1,10 @@ package param -import "testing" -import "reflect" -import "time" +import ( + "reflect" + "testing" + "time" +) type testDefinition struct { strValue string @@ -11,47 +13,45 @@ type testDefinition struct { } func Test_Parsers(t *testing.T) { - - //ints + // ints checkParser(testDefinition{"1", 1, intParser{}}, t) checkParser(testDefinition{"-1", int64(-1), intParser{}}, t) checkParser(testDefinition{"1", uint64(1), intParser{}}, t) - //floats + // floats checkParser(testDefinition{"1.0", float32(1.0), floatParser{}}, t) checkParser(testDefinition{"-1.0", float64(-1.0), floatParser{}}, t) - //strings + // strings checkParser(testDefinition{"AB", "AB", stringParser{}}, t) checkParser(testDefinition{"AB", []byte{65, 66}, stringParser{}}, t) - //bools + // bools checkParser(testDefinition{"true", true, boolParser{}}, t) checkParser(testDefinition{"0", false, boolParser{}}, t) - //timeParser + // timeParser checkParser(testDefinition{"2017-05-30T13:54:53Z", time.Date(2017, 5, 30, 13, 54, 53, 0, time.UTC), timeParser{}}, t) checkParser(testDefinition{"2017-05-30", time.Date(2017, 5, 30, 0, 0, 0, 0, time.UTC), timeParser{}}, t) - //json + // json checkParser(testDefinition{`{"X": 5, "Y":"Z"}`, struct { X int Y string }{5, "Z"}, jsonParser{}}, t) - //slice in query is parsed as comma delimited + // slice in query is parsed as comma delimited checkParser(testDefinition{`1,2`, []int{1, 2}, sliceParser(intParser{})}, t) - //slice in body is parsed as json + // slice in body is parsed as json checkParser(testDefinition{`["a","b"]`, []string{"a", "b"}, jsonParser{}}, t, MethodParam{in: body}) - //pointers - var someInt = 1 + // pointers + someInt := 1 checkParser(testDefinition{`1`, &someInt, ptrParser(intParser{})}, t) - var someStruct = struct{ X int }{5} + someStruct := struct{ X int }{5} checkParser(testDefinition{`{"X": 5}`, &someStruct, jsonParser{}}, t) - } func checkParser(def testDefinition, t *testing.T, methodParam ...MethodParam) { diff --git a/context/renderer.go b/server/web/context/renderer.go similarity index 77% rename from context/renderer.go rename to server/web/context/renderer.go index 36a7cb53fe..5a0783324f 100644 --- a/context/renderer.go +++ b/server/web/context/renderer.go @@ -1,6 +1,6 @@ package context -// Renderer defines an http response renderer +// Renderer defines a http response renderer type Renderer interface { Render(ctx *Context) } diff --git a/context/response.go b/server/web/context/response.go similarity index 66% rename from context/response.go rename to server/web/context/response.go index 9c3c715a2d..86b2c0b887 100644 --- a/context/response.go +++ b/server/web/context/response.go @@ -1,27 +1,26 @@ package context import ( - "strconv" - "net/http" + "strconv" ) const ( - //BadRequest indicates http error 400 + // BadRequest indicates HTTP error 400 BadRequest StatusCode = http.StatusBadRequest - //NotFound indicates http error 404 + // NotFound indicates HTTP error 404 NotFound StatusCode = http.StatusNotFound ) -// StatusCode sets the http response status code +// StatusCode sets the HTTP response status code type StatusCode int func (s StatusCode) Error() string { return strconv.Itoa(int(s)) } -// Render sets the http status code +// Render sets the HTTP status code func (s StatusCode) Render(ctx *Context) { ctx.Output.SetStatus(int(s)) } diff --git a/controller.go b/server/web/controller.go similarity index 77% rename from controller.go rename to server/web/controller.go index 0e8853b31e..bebe01e8fc 100644 --- a/controller.go +++ b/server/web/controller.go @@ -12,10 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. -package beego +package web import ( "bytes" + context2 "context" "errors" "fmt" "html/template" @@ -27,10 +28,13 @@ import ( "reflect" "strconv" "strings" + "sync" - "github.com/astaxie/beego/context" - "github.com/astaxie/beego/context/param" - "github.com/astaxie/beego/session" + "google.golang.org/protobuf/proto" + + "github.com/beego/beego/v2/server/web/context" + "github.com/beego/beego/v2/server/web/context/param" + "github.com/beego/beego/v2/server/web/session" ) var ( @@ -38,8 +42,21 @@ var ( ErrAbort = errors.New("user stop run") // GlobalControllerRouter store comments with controller. pkgpath+controller:comments GlobalControllerRouter = make(map[string][]ControllerComments) + copyBufferPool sync.Pool +) + +const ( + bytePerKb = 1024 + copyBufferKb = 32 + filePerm = 0o666 ) +func init() { + copyBufferPool.New = func() interface{} { + return make([]byte, bytePerKb*copyBufferKb) + } +} + // ControllerFilter store the filter for controller type ControllerFilter struct { Pattern string @@ -79,9 +96,11 @@ type ControllerComments struct { // ControllerCommentsSlice implements the sort interface type ControllerCommentsSlice []ControllerComments -func (p ControllerCommentsSlice) Len() int { return len(p) } +func (p ControllerCommentsSlice) Len() int { return len(p) } + func (p ControllerCommentsSlice) Less(i, j int) bool { return p[i].Router < p[j].Router } -func (p ControllerCommentsSlice) Swap(i, j int) { p[i], p[j] = p[j], p[i] } + +func (p ControllerCommentsSlice) Swap(i, j int) { p[i], p[j] = p[j], p[i] } // Controller defines some basic http request handler operations, such as // http context, template and view, session and xsrf. @@ -106,9 +125,9 @@ type Controller struct { EnableRender bool // xsrf data + EnableXSRF bool _xsrfToken string XSRFExpire int - EnableXSRF bool // session CruSession session.Store @@ -224,6 +243,37 @@ func (c *Controller) HandlerFunc(fnname string) bool { // URLMapping register the internal Controller router. func (c *Controller) URLMapping() {} +// Bind if the content type is form, we read data from form +// otherwise, read data from request body +func (c *Controller) Bind(obj interface{}) error { + return c.Ctx.Bind(obj) +} + +// BindYAML only read data from http request body +func (c *Controller) BindYAML(obj interface{}) error { + return c.Ctx.BindYAML(obj) +} + +// BindForm read data from form +func (c *Controller) BindForm(obj interface{}) error { + return c.Ctx.BindForm(obj) +} + +// BindJSON only read data from http request body +func (c *Controller) BindJSON(obj interface{}) error { + return c.Ctx.BindJSON(obj) +} + +// BindProtobuf only read data from http request body +func (c *Controller) BindProtobuf(obj proto.Message) error { + return c.Ctx.BindProtobuf(obj) +} + +// BindXML only read data from http request body +func (c *Controller) BindXML(obj interface{}) error { + return c.Ctx.BindXML(obj) +} + // Mapping the method to function func (c *Controller) Mapping(method string, fn func()) { c.methodMapping[method] = fn @@ -249,33 +299,34 @@ func (c *Controller) Render() error { // RenderString returns the rendered template string. Do not send out response. func (c *Controller) RenderString() (string, error) { b, e := c.RenderBytes() + if e != nil { + return "", e + } return string(b), e } // RenderBytes returns the bytes of rendered template string. Do not send out response. func (c *Controller) RenderBytes() ([]byte, error) { buf, err := c.renderTemplate() - //if the controller has set layout, then first get the tplName's content set the content to the layout + // if the controller has set layout, then first get the tplName's content set the content to the layout if err == nil && c.Layout != "" { c.Data["LayoutContent"] = template.HTML(buf.String()) - if c.LayoutSections != nil { - for sectionName, sectionTpl := range c.LayoutSections { - if sectionTpl == "" { - c.Data[sectionName] = "" - continue - } - buf.Reset() - err = ExecuteViewPathTemplate(&buf, sectionTpl, c.viewPath(), c.Data) - if err != nil { - return nil, err - } - c.Data[sectionName] = template.HTML(buf.String()) + for sectionName, sectionTpl := range c.LayoutSections { + if sectionTpl == "" { + c.Data[sectionName] = "" + continue + } + buf.Reset() + err = ExecuteViewPathTemplate(&buf, sectionTpl, c.viewPath(), c.Data) + if err != nil { + return nil, err } + c.Data[sectionName] = template.HTML(buf.String()) } buf.Reset() - ExecuteViewPathTemplate(&buf, c.Layout, c.viewPath(), c.Data) + err = ExecuteViewPathTemplate(&buf, c.Layout, c.viewPath(), c.Data) } return buf.Bytes(), err } @@ -292,13 +343,11 @@ func (c *Controller) renderTemplate() (bytes.Buffer, error) { buildFiles := []string{c.TplName} if c.Layout != "" { buildFiles = append(buildFiles, c.Layout) - if c.LayoutSections != nil { - for _, sectionTpl := range c.LayoutSections { - if sectionTpl == "" { - continue - } - buildFiles = append(buildFiles, sectionTpl) + for _, sectionTpl := range c.LayoutSections { + if sectionTpl == "" { + continue } + buildFiles = append(buildFiles, sectionTpl) } } BuildTemplate(c.viewPath(), buildFiles...) @@ -343,9 +392,9 @@ func (c *Controller) Abort(code string) { // CustomAbort stops controller handler and show the error data, it's similar Aborts, but support status code and body. func (c *Controller) CustomAbort(status int, body string) { + c.Ctx.Output.Status = status // first panic from ErrorMaps, it is user defined error functions. if _, ok := ErrorMaps[body]; ok { - c.Ctx.Output.Status = status panic(body) } // last panic user string @@ -371,51 +420,74 @@ func (c *Controller) URLFor(endpoint string, values ...interface{}) string { return URLFor(endpoint, values...) } +func (c *Controller) JSONResp(data interface{}) error { + return c.Ctx.JSONResp(data) +} + +func (c *Controller) XMLResp(data interface{}) error { + return c.Ctx.XMLResp(data) +} + +func (c *Controller) YamlResp(data interface{}) error { + return c.Ctx.YamlResp(data) +} + +// Resp sends response based on the Accept Header +// By default response will be in JSON +// it's different from ServeXXX methods +// because we don't store the data to Data field +func (c *Controller) Resp(data interface{}) error { + return c.Ctx.Resp(data) +} + // ServeJSON sends a json response with encoding charset. -func (c *Controller) ServeJSON(encoding ...bool) { +func (c *Controller) ServeJSON(encoding ...bool) error { var ( hasIndent = BConfig.RunMode != PROD hasEncoding = len(encoding) > 0 && encoding[0] ) - c.Ctx.Output.JSON(c.Data["json"], hasIndent, hasEncoding) + return c.Ctx.Output.JSON(c.Data["json"], hasIndent, hasEncoding) } // ServeJSONP sends a jsonp response. -func (c *Controller) ServeJSONP() { +func (c *Controller) ServeJSONP() error { hasIndent := BConfig.RunMode != PROD - c.Ctx.Output.JSONP(c.Data["jsonp"], hasIndent) + return c.Ctx.Output.JSONP(c.Data["jsonp"], hasIndent) } // ServeXML sends xml response. -func (c *Controller) ServeXML() { +func (c *Controller) ServeXML() error { hasIndent := BConfig.RunMode != PROD - c.Ctx.Output.XML(c.Data["xml"], hasIndent) + return c.Ctx.Output.XML(c.Data["xml"], hasIndent) } // ServeYAML sends yaml response. -func (c *Controller) ServeYAML() { - c.Ctx.Output.YAML(c.Data["yaml"]) +func (c *Controller) ServeYAML() error { + return c.Ctx.Output.YAML(c.Data["yaml"]) } // ServeFormatted serve YAML, XML OR JSON, depending on the value of the Accept header -func (c *Controller) ServeFormatted(encoding ...bool) { +func (c *Controller) ServeFormatted(encoding ...bool) error { hasIndent := BConfig.RunMode != PROD hasEncoding := len(encoding) > 0 && encoding[0] - c.Ctx.Output.ServeFormatted(c.Data, hasIndent, hasEncoding) + return c.Ctx.Output.ServeFormatted(c.Data, hasIndent, hasEncoding) } // Input returns the input data map from POST or PUT request body and query string. -func (c *Controller) Input() url.Values { +func (c *Controller) Input() (url.Values, error) { if c.Ctx.Request.Form == nil { - c.Ctx.Request.ParseForm() + err := c.Ctx.Request.ParseForm() + if err != nil { + return nil, err + } } - return c.Ctx.Request.Form + return c.Ctx.Request.Form, nil } // ParseForm maps input data map to obj struct. func (c *Controller) ParseForm(obj interface{}) error { - return ParseForm(c.Input(), obj) + return c.Ctx.BindForm(obj) } // GetString returns the input value by key string or the default value while it's present and input is blank @@ -437,7 +509,7 @@ func (c *Controller) GetStrings(key string, def ...[]string) []string { defv = def[0] } - if f := c.Input(); f == nil { + if f, err := c.Input(); f == nil || err != nil { return defv } else if vs := f[key]; len(vs) > 0 { return vs @@ -559,31 +631,33 @@ func (c *Controller) GetFile(key string) (multipart.File, *multipart.FileHeader, // GetFiles return multi-upload files // files, err:=c.GetFiles("myfiles") +// // if err != nil { // http.Error(w, err.Error(), http.StatusNoContent) // return // } -// for i, _ := range files { -// //for each fileheader, get a handle to the actual file -// file, err := files[i].Open() -// defer file.Close() -// if err != nil { -// http.Error(w, err.Error(), http.StatusInternalServerError) -// return -// } -// //create destination file making sure the path is writeable. -// dst, err := os.Create("upload/" + files[i].Filename) -// defer dst.Close() -// if err != nil { -// http.Error(w, err.Error(), http.StatusInternalServerError) -// return -// } -// //copy the uploaded file to the destination file -// if _, err := io.Copy(dst, file); err != nil { -// http.Error(w, err.Error(), http.StatusInternalServerError) -// return +// +// for i, _ := range files { +// //for each fileheader, get a handle to the actual file +// file, err := files[i].Open() +// defer file.Close() +// if err != nil { +// http.Error(w, err.Error(), http.StatusInternalServerError) +// return +// } +// //create destination file making sure the path is writeable. +// dst, err := os.Create("upload/" + files[i].Filename) +// defer dst.Close() +// if err != nil { +// http.Error(w, err.Error(), http.StatusInternalServerError) +// return +// } +// //copy the uploaded file to the destination file +// if _, err := io.Copy(dst, file); err != nil { +// http.Error(w, err.Error(), http.StatusInternalServerError) +// return +// } // } -// } func (c *Controller) GetFiles(key string) ([]*multipart.FileHeader, error) { if files, ok := c.Ctx.Request.MultipartForm.File[key]; ok { return files, nil @@ -593,19 +667,36 @@ func (c *Controller) GetFiles(key string) ([]*multipart.FileHeader, error) { // SaveToFile saves uploaded file to new path. // it only operates the first one of mutil-upload form file field. -func (c *Controller) SaveToFile(fromfile, tofile string) error { - file, _, err := c.Ctx.Request.FormFile(fromfile) +func (c *Controller) SaveToFile(fromFile, toFile string) error { + buf := copyBufferPool.Get().([]byte) + defer copyBufferPool.Put(buf) + return c.SaveToFileWithBuffer(fromFile, toFile, buf) +} + +type onlyWriter struct { + io.Writer +} + +func (c *Controller) SaveToFileWithBuffer(fromFile string, toFile string, buf []byte) error { + src, _, err := c.Ctx.Request.FormFile(fromFile) if err != nil { return err } - defer file.Close() - f, err := os.OpenFile(tofile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0666) + defer src.Close() + + dst, err := os.OpenFile(toFile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, filePerm) if err != nil { return err } - defer f.Close() - io.Copy(f, file) - return nil + defer dst.Close() + + _, err = io.CopyBuffer(onlyWriter{dst}, src, buf) + if err != nil { + return err + } + + err = c.Ctx.Request.MultipartForm.RemoveAll() + return err } // StartSession starts session and load old session data info this controller. @@ -617,11 +708,11 @@ func (c *Controller) StartSession() session.Store { } // SetSession puts value into session. -func (c *Controller) SetSession(name interface{}, value interface{}) { +func (c *Controller) SetSession(name interface{}, value interface{}) error { if c.CruSession == nil { c.StartSession() } - c.CruSession.Set(name, value) + return c.CruSession.Set(context2.Background(), name, value) } // GetSession gets value from session. @@ -629,32 +720,38 @@ func (c *Controller) GetSession(name interface{}) interface{} { if c.CruSession == nil { c.StartSession() } - return c.CruSession.Get(name) + return c.CruSession.Get(context2.Background(), name) } // DelSession removes value from session. -func (c *Controller) DelSession(name interface{}) { +func (c *Controller) DelSession(name interface{}) error { if c.CruSession == nil { c.StartSession() } - c.CruSession.Delete(name) + return c.CruSession.Delete(context2.Background(), name) } // SessionRegenerateID regenerates session id for this session. // the session data have no changes. -func (c *Controller) SessionRegenerateID() { +func (c *Controller) SessionRegenerateID() error { if c.CruSession != nil { - c.CruSession.SessionRelease(c.Ctx.ResponseWriter) + c.CruSession.SessionRelease(context2.Background(), c.Ctx.ResponseWriter) } - c.CruSession = GlobalSessions.SessionRegenerateID(c.Ctx.ResponseWriter, c.Ctx.Request) + var err error + c.CruSession, err = GlobalSessions.SessionRegenerateID(c.Ctx.ResponseWriter, c.Ctx.Request) c.Ctx.Input.CruSession = c.CruSession + return err } // DestroySession cleans session data and session cookie. -func (c *Controller) DestroySession() { - c.Ctx.Input.CruSession.Flush() +func (c *Controller) DestroySession() error { + err := c.Ctx.Input.CruSession.Flush(nil) + if err != nil { + return err + } c.Ctx.Input.CruSession = nil GlobalSessions.SessionDestroy(c.Ctx.ResponseWriter, c.Ctx.Request) + return nil } // IsAjax returns this request is ajax or not. diff --git a/server/web/controller_test.go b/server/web/controller_test.go new file mode 100644 index 0000000000..e27b54f8b4 --- /dev/null +++ b/server/web/controller_test.go @@ -0,0 +1,400 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package web + +import ( + "bytes" + "io" + "math" + "mime/multipart" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "strconv" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/beego/beego/v2/server/web/context" +) + +var ( + fileKey = "file1" + testFile = filepath.Join(currentWorkDir, "test/static/file.txt") + toFile = filepath.Join(currentWorkDir, "test/static/file2.txt") +) + +func TestGetInt(t *testing.T) { + i := context.NewInput() + i.SetParam("age", "40") + ctx := &context.Context{Input: i} + ctrlr := Controller{Ctx: ctx} + val, _ := ctrlr.GetInt("age") + if val != 40 { + t.Errorf("TestGetInt expect 40,get %T,%v", val, val) + } +} + +func TestGetInt8(t *testing.T) { + i := context.NewInput() + i.SetParam("age", "40") + ctx := &context.Context{Input: i} + ctrlr := Controller{Ctx: ctx} + val, _ := ctrlr.GetInt8("age") + if val != 40 { + t.Errorf("TestGetInt8 expect 40,get %T,%v", val, val) + } + // Output: int8 +} + +func TestGetInt16(t *testing.T) { + i := context.NewInput() + i.SetParam("age", "40") + ctx := &context.Context{Input: i} + ctrlr := Controller{Ctx: ctx} + val, _ := ctrlr.GetInt16("age") + if val != 40 { + t.Errorf("TestGetInt16 expect 40,get %T,%v", val, val) + } +} + +func TestGetInt32(t *testing.T) { + i := context.NewInput() + i.SetParam("age", "40") + ctx := &context.Context{Input: i} + ctrlr := Controller{Ctx: ctx} + val, _ := ctrlr.GetInt32("age") + if val != 40 { + t.Errorf("TestGetInt32 expect 40,get %T,%v", val, val) + } +} + +func TestGetInt64(t *testing.T) { + i := context.NewInput() + i.SetParam("age", "40") + ctx := &context.Context{Input: i} + ctrlr := Controller{Ctx: ctx} + val, _ := ctrlr.GetInt64("age") + if val != 40 { + t.Errorf("TestGeetInt64 expect 40,get %T,%v", val, val) + } +} + +func TestGetUint8(t *testing.T) { + i := context.NewInput() + i.SetParam("age", strconv.FormatUint(math.MaxUint8, 10)) + ctx := &context.Context{Input: i} + ctrlr := Controller{Ctx: ctx} + val, _ := ctrlr.GetUint8("age") + if val != math.MaxUint8 { + t.Errorf("TestGetUint8 expect %v,get %T,%v", math.MaxUint8, val, val) + } +} + +func TestGetUint16(t *testing.T) { + i := context.NewInput() + i.SetParam("age", strconv.FormatUint(math.MaxUint16, 10)) + ctx := &context.Context{Input: i} + ctrlr := Controller{Ctx: ctx} + val, _ := ctrlr.GetUint16("age") + if val != math.MaxUint16 { + t.Errorf("TestGetUint16 expect %v,get %T,%v", math.MaxUint16, val, val) + } +} + +func TestGetUint32(t *testing.T) { + i := context.NewInput() + i.SetParam("age", strconv.FormatUint(math.MaxUint32, 10)) + ctx := &context.Context{Input: i} + ctrlr := Controller{Ctx: ctx} + val, _ := ctrlr.GetUint32("age") + if val != math.MaxUint32 { + t.Errorf("TestGetUint32 expect %v,get %T,%v", math.MaxUint32, val, val) + } +} + +func TestGetUint64(t *testing.T) { + i := context.NewInput() + i.SetParam("age", strconv.FormatUint(math.MaxUint64, 10)) + ctx := &context.Context{Input: i} + ctrlr := Controller{Ctx: ctx} + val, _ := ctrlr.GetUint64("age") + if val != math.MaxUint64 { + t.Errorf("TestGetUint64 expect %v,get %T,%v", uint64(math.MaxUint64), val, val) + } +} + +func TestAdditionalViewPaths(t *testing.T) { + tmpDir := os.TempDir() + dir1 := filepath.Join(tmpDir, "_beeTmp", "TestAdditionalViewPaths") + dir2 := filepath.Join(tmpDir, "_beeTmp2", "TestAdditionalViewPaths") + defer os.RemoveAll(dir1) + defer os.RemoveAll(dir2) + + dir1file := "file1.tpl" + dir2file := "file2.tpl" + + genFile := func(dir string, name string, content string) { + os.MkdirAll(filepath.Dir(filepath.Join(dir, name)), 0o777) + if f, err := os.Create(filepath.Join(dir, name)); err != nil { + t.Fatal(err) + } else { + defer f.Close() + f.WriteString(content) + f.Close() + } + } + genFile(dir1, dir1file, `
{{.Content}}
`) + genFile(dir2, dir2file, `{{.Content}}`) + + AddViewPath(dir1) + AddViewPath(dir2) + + ctrl := Controller{ + TplName: "file1.tpl", + ViewPath: dir1, + } + ctrl.Data = map[interface{}]interface{}{ + "Content": "value2", + } + if result, err := ctrl.RenderString(); err != nil { + t.Fatal(err) + } else { + if result != "
value2
" { + t.Fatalf("TestAdditionalViewPaths expect %s got %s", "
value2
", result) + } + } + + func() { + ctrl.TplName = "file2.tpl" + defer func() { + if r := recover(); r == nil { + t.Fatal("TestAdditionalViewPaths expected error") + } + }() + ctrl.RenderString() + }() + + ctrl.TplName = "file2.tpl" + ctrl.ViewPath = dir2 + ctrl.RenderString() +} + +func TestBindJson(t *testing.T) { + var s struct { + Foo string `json:"foo"` + } + header := map[string][]string{"Content-Type": {"application/json"}} + request := &http.Request{Header: header} + input := &context.BeegoInput{RequestBody: []byte(`{"foo": "FOO"}`)} + ctx := &context.Context{Request: request, Input: input} + ctrlr := Controller{Ctx: ctx} + err := ctrlr.Bind(&s) + require.NoError(t, err) + assert.Equal(t, "FOO", s.Foo) +} + +func TestBindNoContentType(t *testing.T) { + var s struct { + Foo string `json:"foo"` + } + header := map[string][]string{} + request := &http.Request{Header: header} + input := &context.BeegoInput{RequestBody: []byte(`{"foo": "FOO"}`)} + ctx := &context.Context{Request: request, Input: input} + ctrlr := Controller{Ctx: ctx} + err := ctrlr.Bind(&s) + require.NoError(t, err) + assert.Equal(t, "FOO", s.Foo) +} + +func TestBindXML(t *testing.T) { + var s struct { + Foo string `xml:"foo"` + } + xmlBody := ` + + FOO +` + header := map[string][]string{"Content-Type": {"text/xml"}} + request := &http.Request{Header: header} + input := &context.BeegoInput{RequestBody: []byte(xmlBody)} + ctx := &context.Context{Request: request, Input: input} + ctrlr := Controller{Ctx: ctx} + err := ctrlr.Bind(&s) + require.NoError(t, err) + assert.Equal(t, "FOO", s.Foo) +} + +func TestBindYAML(t *testing.T) { + var s struct { + Foo string `yaml:"foo"` + } + header := map[string][]string{"Content-Type": {"application/x-yaml"}} + request := &http.Request{Header: header} + input := &context.BeegoInput{RequestBody: []byte("foo: FOO")} + ctx := &context.Context{Request: request, Input: input} + ctrlr := Controller{Ctx: ctx} + err := ctrlr.Bind(&s) + require.NoError(t, err) + assert.Equal(t, "FOO", s.Foo) +} + +type TestRespController struct { + Controller +} + +func (t *TestRespController) TestResponse() { + type S struct { + Foo string `json:"foo" xml:"foo" yaml:"foo"` + } + + bar := S{Foo: "bar"} + + _ = t.Resp(bar) +} + +func (t *TestRespController) TestSaveToFile() { + err := t.SaveToFile(fileKey, toFile) + if err != nil { + t.Ctx.WriteString("save file fail") + } + err = os.Remove(toFile) + if err != nil { + t.Ctx.WriteString("save file fail") + } + t.Ctx.WriteString("save file success") +} + +type respTestCase struct { + Accept string + ExpectedContentLength int64 + ExpectedResponse string +} + +func TestControllerResp(t *testing.T) { + // test cases + tcs := []respTestCase{ + {Accept: context.ApplicationJSON, ExpectedContentLength: 13, ExpectedResponse: `{"foo":"bar"}`}, + {Accept: context.ApplicationXML, ExpectedContentLength: 21, ExpectedResponse: `bar`}, + {Accept: context.ApplicationYAML, ExpectedContentLength: 9, ExpectedResponse: "foo: bar\n"}, + {Accept: "OTHER", ExpectedContentLength: 13, ExpectedResponse: `{"foo":"bar"}`}, + } + + for _, tc := range tcs { + testControllerRespTestCases(t, tc) + } +} + +func testControllerRespTestCases(t *testing.T, tc respTestCase) { + // create fake GET request + r, _ := http.NewRequest("GET", "/", nil) + r.Header.Set("Accept", tc.Accept) + w := httptest.NewRecorder() + + // setup the handler + handler := NewControllerRegister() + handler.Add("/", &TestRespController{}, WithRouterMethods(&TestRespController{}, "get:TestResponse")) + handler.ServeHTTP(w, r) + + response := w.Result() + if response.ContentLength != tc.ExpectedContentLength { + t.Errorf("TestResponse() unable to validate content length %d for %s", response.ContentLength, tc.Accept) + } + + if response.StatusCode != http.StatusOK { + t.Errorf("TestResponse() failed to validate response code for %s", tc.Accept) + } + + bodyBytes, err := io.ReadAll(response.Body) + if err != nil { + t.Errorf("TestResponse() failed to parse response body for %s", tc.Accept) + } + bodyString := string(bodyBytes) + if bodyString != tc.ExpectedResponse { + t.Errorf("TestResponse() failed to validate response body '%s' for %s", bodyString, tc.Accept) + } +} + +func createReqBody(filePath string) (string, io.Reader, error) { + var err error + + buf := new(bytes.Buffer) + bw := multipart.NewWriter(buf) // body writer + + f, err := os.Open(filePath) + if err != nil { + return "", nil, err + } + defer func() { + _ = f.Close() + }() + + // text part1 + p1w, _ := bw.CreateFormField("name") + _, err = p1w.Write([]byte("Tony Bai")) + if err != nil { + return "", nil, err + } + + // text part2 + p2w, _ := bw.CreateFormField("age") + _, err = p2w.Write([]byte("15")) + if err != nil { + return "", nil, err + } + + // file part1 + _, fileName := filepath.Split(filePath) + fw1, _ := bw.CreateFormFile(fileKey, fileName) + _, err = io.Copy(fw1, f) + if err != nil { + return "", nil, err + } + + _ = bw.Close() // write the tail boundary + return bw.FormDataContentType(), buf, nil +} + +func TestControllerSaveFile(t *testing.T) { + // create body + contType, bodyReader, err := createReqBody(testFile) + assert.NoError(t, err) + + // create fake POST request + r, _ := http.NewRequest("POST", "/upload_file", bodyReader) + r.Header.Set("Accept", context.ApplicationForm) + r.Header.Set("Content-Type", contType) + w := httptest.NewRecorder() + + // setup the handler + handler := NewControllerRegister() + handler.Add("/upload_file", &TestRespController{}, + WithRouterMethods(&TestRespController{}, "post:TestSaveToFile")) + handler.ServeHTTP(w, r) + + response := w.Result() + bs := make([]byte, 100) + n, err := response.Body.Read(bs) + assert.NoError(t, err) + if string(bs[:n]) == "save file fail" { + t.Errorf("TestSaveToFile() failed to validate response") + } + if response.StatusCode != http.StatusOK { + t.Errorf("TestSaveToFile() failed to validate response code for %s", context.ApplicationJSON) + } +} diff --git a/server/web/doc.go b/server/web/doc.go new file mode 100644 index 0000000000..2ea053e40f --- /dev/null +++ b/server/web/doc.go @@ -0,0 +1,15 @@ +/* +Package beego provide a MVC framework +beego: an open-source, high-performance, modular, full-stack web framework + +It is used for rapid development of RESTful APIs, web apps and backend services in Go. +beego is inspired by Tornado, Sinatra and Flask with the added benefit of some Go-specific features such as interfaces and struct embedding. + + package main + import "github.com/beego/beego/v2" + + func main() { + beego.Run() + } +*/ +package web diff --git a/error.go b/server/web/error.go similarity index 92% rename from error.go rename to server/web/error.go index e5e9fd4742..da35ad8a41 100644 --- a/error.go +++ b/server/web/error.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package beego +package web import ( "fmt" @@ -23,12 +23,13 @@ import ( "strconv" "strings" - "github.com/astaxie/beego/context" - "github.com/astaxie/beego/utils" + "github.com/beego/beego/v2" + "github.com/beego/beego/v2/core/utils" + "github.com/beego/beego/v2/server/web/context" ) const ( - errorTypeHandler = iota + errorTypeHandler = iota errorTypeController ) @@ -90,7 +91,7 @@ func showErr(err interface{}, ctx *context.Context, stack string) { "RequestURL": ctx.Input.URI(), "RemoteAddr": ctx.Input.IP(), "Stack": stack, - "BeegoVersion": VERSION, + "BeegoVersion": beego.VERSION, "GoVersion": runtime.Version(), } t.Execute(ctx.ResponseWriter, data) @@ -359,11 +360,25 @@ func gatewayTimeout(rw http.ResponseWriter, r *http.Request) { ) } +// show 413 Payload Too Large +func payloadTooLarge(rw http.ResponseWriter, r *http.Request) { + responseError(rw, r, + 413, + `
The page you have requested is unavailable. +
Perhaps you are here because:

+
    +
    The request entity is larger than limits defined by server. +
    Please change the request entity and try again. +
+ `, + ) +} + func responseError(rw http.ResponseWriter, r *http.Request, errCode int, errContent string) { t, _ := template.New("beegoerrortemp").Parse(errtpl) data := M{ "Title": http.StatusText(errCode), - "BeegoVersion": VERSION, + "BeegoVersion": beego.VERSION, "Content": template.HTML(errContent), } t.Execute(rw, data) @@ -371,9 +386,10 @@ func responseError(rw http.ResponseWriter, r *http.Request, errCode int, errCont // ErrorHandler registers http.HandlerFunc to each http err code string. // usage: -// beego.ErrorHandler("404",NotFound) +// +// beego.ErrorHandler("404",NotFound) // beego.ErrorHandler("500",InternalServerError) -func ErrorHandler(code string, h http.HandlerFunc) *App { +func ErrorHandler(code string, h http.HandlerFunc) *HttpServer { ErrorMaps[code] = &errorInfo{ errorType: errorTypeHandler, handler: h, @@ -384,8 +400,9 @@ func ErrorHandler(code string, h http.HandlerFunc) *App { // ErrorController registers ControllerInterface to each http err code string. // usage: -// beego.ErrorController(&controllers.ErrorController{}) -func ErrorController(c ControllerInterface) *App { +// +// beego.ErrorController(&controllers.ErrorController{}) +func ErrorController(c ControllerInterface) *HttpServer { reflectVal := reflect.ValueOf(c) rt := reflectVal.Type() ct := reflect.Indirect(reflectVal).Type() @@ -428,13 +445,13 @@ func exception(errCode string, ctx *context.Context) { return } } - //if 50x error has been removed from errorMap + // if 50x error has been removed from errorMap ctx.ResponseWriter.WriteHeader(atoi(errCode)) ctx.WriteString(errCode) } func executeError(err *errorInfo, ctx *context.Context, code int) { - //make sure to log the error in the access log + // make sure to log the error in the access log LogAccess(ctx, nil, code) if err.errorType == errorTypeHandler { @@ -444,16 +461,16 @@ func executeError(err *errorInfo, ctx *context.Context, code int) { } if err.errorType == errorTypeController { ctx.Output.SetStatus(code) - //Invoke the request handler + // Invoke the request handler vc := reflect.New(err.controllerType) execController, ok := vc.Interface().(ControllerInterface) if !ok { panic("controller is not ControllerInterface") } - //call the controller init function + // call the controller init function execController.Init(ctx, err.controllerType.Name(), err.method, vc.Interface()) - //call prepare function + // call prepare function execController.Prepare() execController.URLMapping() @@ -461,7 +478,7 @@ func executeError(err *errorInfo, ctx *context.Context, code int) { method := vc.MethodByName(err.method) method.Call([]reflect.Value{}) - //render template + // render template if BConfig.WebConfig.AutoRender { if err := execController.Render(); err != nil { panic(err) diff --git a/error_test.go b/server/web/error_test.go similarity index 99% rename from error_test.go rename to server/web/error_test.go index 378aa9538a..2685a9856b 100644 --- a/error_test.go +++ b/server/web/error_test.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package beego +package web import ( "net/http" diff --git a/server/web/filter.go b/server/web/filter.go new file mode 100644 index 0000000000..8ada5b3297 --- /dev/null +++ b/server/web/filter.go @@ -0,0 +1,136 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package web + +import ( + "strings" + + "github.com/beego/beego/v2/server/web/context" +) + +// FilterChain is different from pure FilterFunc +// when you use this, you must invoke next(ctx) inside the FilterFunc which is returned +// And all those FilterChain will be invoked before other FilterFunc +type FilterChain func(next FilterFunc) FilterFunc + +// FilterFunc defines a filter function which is invoked before the controller handler is executed. +// It's a alias of HandleFunc +// In fact, the HandleFunc is the last Filter. This is the truth +type FilterFunc = HandleFunc + +// FilterRouter defines a filter operation which is invoked before the controller handler is executed. +// It can match the URL against a pattern, and execute a filter function +// when a request with a matching URL arrives. +type FilterRouter struct { + filterFunc FilterFunc + next *FilterRouter + tree *Tree + pattern string + returnOnOutput bool + resetParams bool +} + +// params is for: +// 1. setting the returnOnOutput value (false allows multiple filters to execute) +// 2. determining whether or not params need to be reset. +func newFilterRouter(pattern string, filter FilterFunc, opts ...FilterOpt) *FilterRouter { + mr := &FilterRouter{ + tree: NewTree(), + pattern: pattern, + filterFunc: filter, + } + + fos := &filterOpts{ + returnOnOutput: true, + } + + for _, o := range opts { + o(fos) + } + + if !fos.routerCaseSensitive { + mr.pattern = strings.ToLower(pattern) + } + + mr.returnOnOutput = fos.returnOnOutput + mr.resetParams = fos.resetParams + mr.tree.AddRouter(pattern, true) + return mr +} + +// filter will check whether we need to execute the filter logic +// return (started, done) +func (f *FilterRouter) filter(ctx *context.Context, urlPath string, preFilterParams map[string]string) (bool, bool) { + if f.returnOnOutput && ctx.ResponseWriter.Started { + return true, true + } + if f.resetParams { + preFilterParams = ctx.Input.Params() + } + if ok := f.ValidRouter(urlPath, ctx); ok { + f.filterFunc(ctx) + if f.resetParams { + ctx.Input.ResetParams() + for k, v := range preFilterParams { + ctx.Input.SetParam(k, v) + } + } + } else if f.next != nil { + return f.next.filter(ctx, urlPath, preFilterParams) + } + if f.returnOnOutput && ctx.ResponseWriter.Started { + return true, true + } + return false, false +} + +// ValidRouter checks if the current request is matched by this filter. +// If the request is matched, the values of the URL parameters defined +// by the filter pattern are also returned. +func (f *FilterRouter) ValidRouter(url string, ctx *context.Context) bool { + isOk := f.tree.Match(url, ctx) + if isOk != nil { + if b, ok := isOk.(bool); ok { + return b + } + } + return false +} + +type filterOpts struct { + returnOnOutput bool + resetParams bool + routerCaseSensitive bool +} + +type FilterOpt func(opts *filterOpts) + +func WithReturnOnOutput(ret bool) FilterOpt { + return func(opts *filterOpts) { + opts.returnOnOutput = ret + } +} + +func WithResetParams(reset bool) FilterOpt { + return func(opts *filterOpts) { + opts.resetParams = reset + } +} + +func WithCaseSensitive(sensitive bool) FilterOpt { + return func(opts *filterOpts) { + opts.routerCaseSensitive = sensitive + } +} diff --git a/plugins/apiauth/apiauth.go b/server/web/filter/apiauth/apiauth.go similarity index 69% rename from plugins/apiauth/apiauth.go rename to server/web/filter/apiauth/apiauth.go index 10e25f3f4a..3fa7b2925b 100644 --- a/plugins/apiauth/apiauth.go +++ b/server/web/filter/apiauth/apiauth.go @@ -15,14 +15,15 @@ // Package apiauth provides handlers to enable apiauth support. // // Simple Usage: +// // import( -// "github.com/astaxie/beego" -// "github.com/astaxie/beego/plugins/apiauth" +// "github.com/beego/beego/v2" +// "github.com/beego/beego/v2/server/web/filter/apiauth" // ) // // func main(){ // // apiauth every request -// beego.InsertFilter("*", beego.BeforeRouter,apiauth.APIBaiscAuth("appid","appkey")) +// beego.InsertFilter("*", beego.BeforeRouter,apiauth.APIBasicAuth("appid","appkey")) // beego.Run() // } // @@ -37,11 +38,11 @@ // // Information: // -// In the request user should include these params in the query +// # In the request user should include these params in the query // // 1. appid // -// appid is assigned to the application +// appid is assigned to the application // // 2. signature // @@ -51,8 +52,7 @@ // // 3. timestamp: // -// send the request time, the format is yyyy-mm-dd HH:ii:ss -// +// send the request time, the format is yyyy-mm-dd HH:ii:ss package apiauth import ( @@ -65,15 +65,15 @@ import ( "sort" "time" - "github.com/astaxie/beego" - "github.com/astaxie/beego/context" + "github.com/beego/beego/v2/server/web" + "github.com/beego/beego/v2/server/web/context" ) -// AppIDToAppSecret is used to get appsecret throw appid +// AppIDToAppSecret gets appsecret through appid type AppIDToAppSecret func(string) string -// APIBasicAuth use the basic appid/appkey as the AppIdToAppSecret -func APIBasicAuth(appid, appkey string) beego.FilterFunc { +// APIBasicAuth uses the basic appid/appkey as the AppIdToAppSecret +func APIBasicAuth(appid, appkey string) web.FilterFunc { ft := func(aid string) string { if aid == appid { return appkey @@ -83,59 +83,56 @@ func APIBasicAuth(appid, appkey string) beego.FilterFunc { return APISecretAuth(ft, 300) } -// APIBaiscAuth calls APIBasicAuth for previous callers -func APIBaiscAuth(appid, appkey string) beego.FilterFunc { - return APIBasicAuth(appid, appkey) -} - -// APISecretAuth use AppIdToAppSecret verify and -func APISecretAuth(f AppIDToAppSecret, timeout int) beego.FilterFunc { +// APISecretAuth uses AppIdToAppSecret verify and +func APISecretAuth(f AppIDToAppSecret, timeout int) web.FilterFunc { return func(ctx *context.Context) { if ctx.Input.Query("appid") == "" { ctx.ResponseWriter.WriteHeader(403) - ctx.WriteString("miss query param: appid") + ctx.WriteString("missing query parameter: appid") return } appsecret := f(ctx.Input.Query("appid")) if appsecret == "" { ctx.ResponseWriter.WriteHeader(403) - ctx.WriteString("not exist this appid") + ctx.WriteString("appid query parameter missing") return } if ctx.Input.Query("signature") == "" { ctx.ResponseWriter.WriteHeader(403) - ctx.WriteString("miss query param: signature") + ctx.WriteString("missing query parameter: signature") + return } if ctx.Input.Query("timestamp") == "" { ctx.ResponseWriter.WriteHeader(403) - ctx.WriteString("miss query param: timestamp") + ctx.WriteString("missing query parameter: timestamp") return } u, err := time.Parse("2006-01-02 15:04:05", ctx.Input.Query("timestamp")) if err != nil { ctx.ResponseWriter.WriteHeader(403) - ctx.WriteString("timestamp format is error, should 2006-01-02 15:04:05") + ctx.WriteString("incorrect timestamp format. Should be in the form 2006-01-02 15:04:05") + return } t := time.Now() if t.Sub(u).Seconds() > float64(timeout) { ctx.ResponseWriter.WriteHeader(403) - ctx.WriteString("timeout! the request time is long ago, please try again") + ctx.WriteString("request timer timeout exceeded. Please try again") return } if ctx.Input.Query("signature") != Signature(appsecret, ctx.Input.Method(), ctx.Request.Form, ctx.Input.URL()) { ctx.ResponseWriter.WriteHeader(403) - ctx.WriteString("auth failed") + ctx.WriteString("authentication failed") } } } -// Signature used to generate signature with the appsecret/method/params/RequestURI +// Signature generates signature with appsecret/method/params/RequestURI func Signature(appsecret, method string, params url.Values, RequestURL string) (result string) { var b bytes.Buffer - keys := make([]string, len(params)) + keys := make([]string, 0, len(params)) pa := make(map[string]string) for k, v := range params { pa[k] = v[0] diff --git a/plugins/apiauth/apiauth_test.go b/server/web/filter/apiauth/apiauth_test.go similarity index 100% rename from plugins/apiauth/apiauth_test.go rename to server/web/filter/apiauth/apiauth_test.go diff --git a/plugins/auth/basic.go b/server/web/filter/auth/basic.go similarity index 91% rename from plugins/auth/basic.go rename to server/web/filter/auth/basic.go index c478044abb..2881d7cf40 100644 --- a/plugins/auth/basic.go +++ b/server/web/filter/auth/basic.go @@ -14,9 +14,10 @@ // Package auth provides handlers to enable basic auth support. // Simple Usage: +// // import( -// "github.com/astaxie/beego" -// "github.com/astaxie/beego/plugins/auth" +// "github.com/beego/beego/v2" +// "github.com/beego/beego/v2/server/web/filter/auth" // ) // // func main(){ @@ -25,7 +26,6 @@ // beego.Run() // } // -// // Advanced Usage: // // func SecretAuth(username, password string) bool { @@ -40,14 +40,14 @@ import ( "net/http" "strings" - "github.com/astaxie/beego" - "github.com/astaxie/beego/context" + "github.com/beego/beego/v2/server/web" + "github.com/beego/beego/v2/server/web/context" ) var defaultRealm = "Authorization Required" // Basic is the http basic auth -func Basic(username string, password string) beego.FilterFunc { +func Basic(username string, password string) web.FilterFunc { secrets := func(user, pass string) bool { return user == username && pass == password } @@ -55,7 +55,7 @@ func Basic(username string, password string) beego.FilterFunc { } // NewBasicAuthenticator return the BasicAuth -func NewBasicAuthenticator(secrets SecretProvider, Realm string) beego.FilterFunc { +func NewBasicAuthenticator(secrets SecretProvider, Realm string) web.FilterFunc { return func(ctx *context.Context) { a := &BasicAuth{Secrets: secrets, Realm: Realm} if username := a.CheckAuth(ctx.Request); username == "" { diff --git a/plugins/authz/authz.go b/server/web/filter/authz/authz.go similarity index 91% rename from plugins/authz/authz.go rename to server/web/filter/authz/authz.go index 9dc0db76eb..7475ea6c23 100644 --- a/plugins/authz/authz.go +++ b/server/web/filter/authz/authz.go @@ -14,9 +14,10 @@ // Package authz provides handlers to enable ACL, RBAC, ABAC authorization support. // Simple Usage: +// // import( -// "github.com/astaxie/beego" -// "github.com/astaxie/beego/plugins/authz" +// "github.com/beego/beego/v2" +// "github.com/beego/beego/v2/server/web/filter/authz" // "github.com/casbin/casbin" // ) // @@ -26,7 +27,6 @@ // beego.Run() // } // -// // Advanced Usage: // // func main(){ @@ -40,15 +40,17 @@ package authz import ( - "github.com/astaxie/beego" - "github.com/astaxie/beego/context" - "github.com/casbin/casbin" "net/http" + + "github.com/casbin/casbin" + + "github.com/beego/beego/v2/server/web" + "github.com/beego/beego/v2/server/web/context" ) // NewAuthorizer returns the authorizer. // Use a casbin enforcer as input -func NewAuthorizer(e *casbin.Enforcer) beego.FilterFunc { +func NewAuthorizer(e *casbin.Enforcer) web.FilterFunc { return func(ctx *context.Context) { a := &BasicAuthorizer{enforcer: e} diff --git a/plugins/authz/authz_model.conf b/server/web/filter/authz/authz_model.conf similarity index 92% rename from plugins/authz/authz_model.conf rename to server/web/filter/authz/authz_model.conf index d1b3dbd7aa..fd2f08df6d 100644 --- a/plugins/authz/authz_model.conf +++ b/server/web/filter/authz/authz_model.conf @@ -11,4 +11,4 @@ g = _, _ e = some(where (p.eft == allow)) [matchers] -m = g(r.sub, p.sub) && keyMatch(r.obj, p.obj) && (r.act == p.act || p.act == "*") \ No newline at end of file +m = g(r.sub, p.sub) && keyMatch(r.obj, p.obj) && (r.act == p.act || p.act == "*") diff --git a/plugins/authz/authz_policy.csv b/server/web/filter/authz/authz_policy.csv similarity index 88% rename from plugins/authz/authz_policy.csv rename to server/web/filter/authz/authz_policy.csv index c062dd3e28..9203e11f83 100644 --- a/plugins/authz/authz_policy.csv +++ b/server/web/filter/authz/authz_policy.csv @@ -4,4 +4,4 @@ p, bob, /dataset2/resource1, * p, bob, /dataset2/resource2, GET p, bob, /dataset2/folder1/*, POST p, dataset1_admin, /dataset1/*, * -g, cathy, dataset1_admin \ No newline at end of file +g, cathy, dataset1_admin diff --git a/plugins/authz/authz_test.go b/server/web/filter/authz/authz_test.go similarity index 79% rename from plugins/authz/authz_test.go rename to server/web/filter/authz/authz_test.go index 49aed84cec..3715395465 100644 --- a/plugins/authz/authz_test.go +++ b/server/web/filter/authz/authz_test.go @@ -15,16 +15,18 @@ package authz import ( - "github.com/astaxie/beego" - "github.com/astaxie/beego/context" - "github.com/astaxie/beego/plugins/auth" - "github.com/casbin/casbin" "net/http" "net/http/httptest" "testing" + + "github.com/casbin/casbin" + + "github.com/beego/beego/v2/server/web" + "github.com/beego/beego/v2/server/web/context" + "github.com/beego/beego/v2/server/web/filter/auth" ) -func testRequest(t *testing.T, handler *beego.ControllerRegister, user string, path string, method string, code int) { +func testRequest(t *testing.T, handler *web.ControllerRegister, user string, path string, method string, code int) { r, _ := http.NewRequest(method, path, nil) r.SetBasicAuth(user, "123") w := httptest.NewRecorder() @@ -36,10 +38,10 @@ func testRequest(t *testing.T, handler *beego.ControllerRegister, user string, p } func TestBasic(t *testing.T) { - handler := beego.NewControllerRegister() + handler := web.NewControllerRegister() - handler.InsertFilter("*", beego.BeforeRouter, auth.Basic("alice", "123")) - handler.InsertFilter("*", beego.BeforeRouter, NewAuthorizer(casbin.NewEnforcer("authz_model.conf", "authz_policy.csv"))) + handler.InsertFilter("*", web.BeforeRouter, auth.Basic("alice", "123")) + handler.InsertFilter("*", web.BeforeRouter, NewAuthorizer(casbin.NewEnforcer("authz_model.conf", "authz_policy.csv"))) handler.Any("*", func(ctx *context.Context) { ctx.Output.SetStatus(200) @@ -52,10 +54,10 @@ func TestBasic(t *testing.T) { } func TestPathWildcard(t *testing.T) { - handler := beego.NewControllerRegister() + handler := web.NewControllerRegister() - handler.InsertFilter("*", beego.BeforeRouter, auth.Basic("bob", "123")) - handler.InsertFilter("*", beego.BeforeRouter, NewAuthorizer(casbin.NewEnforcer("authz_model.conf", "authz_policy.csv"))) + handler.InsertFilter("*", web.BeforeRouter, auth.Basic("bob", "123")) + handler.InsertFilter("*", web.BeforeRouter, NewAuthorizer(casbin.NewEnforcer("authz_model.conf", "authz_policy.csv"))) handler.Any("*", func(ctx *context.Context) { ctx.Output.SetStatus(200) @@ -77,11 +79,11 @@ func TestPathWildcard(t *testing.T) { } func TestRBAC(t *testing.T) { - handler := beego.NewControllerRegister() + handler := web.NewControllerRegister() - handler.InsertFilter("*", beego.BeforeRouter, auth.Basic("cathy", "123")) + handler.InsertFilter("*", web.BeforeRouter, auth.Basic("cathy", "123")) e := casbin.NewEnforcer("authz_model.conf", "authz_policy.csv") - handler.InsertFilter("*", beego.BeforeRouter, NewAuthorizer(e)) + handler.InsertFilter("*", web.BeforeRouter, NewAuthorizer(e)) handler.Any("*", func(ctx *context.Context) { ctx.Output.SetStatus(200) diff --git a/plugins/cors/cors.go b/server/web/filter/cors/cors.go similarity index 95% rename from plugins/cors/cors.go rename to server/web/filter/cors/cors.go index 45c327ab46..4c5cb5b71e 100644 --- a/plugins/cors/cors.go +++ b/server/web/filter/cors/cors.go @@ -14,9 +14,11 @@ // Package cors provides handlers to enable CORS support. // Usage +// // import ( -// "github.com/astaxie/beego" -// "github.com/astaxie/beego/plugins/cors" +// "github.com/beego/beego/v2" +// "github.com/beego/beego/v2/server/web/filter/cors" +// // ) // // func main() { @@ -42,8 +44,8 @@ import ( "strings" "time" - "github.com/astaxie/beego" - "github.com/astaxie/beego/context" + "github.com/beego/beego/v2/server/web" + "github.com/beego/beego/v2/server/web/context" ) const ( @@ -69,10 +71,10 @@ var ( type Options struct { // If set, all origins are allowed. AllowAllOrigins bool - // A list of allowed origins. Wild cards and FQDNs are supported. - AllowOrigins []string // If set, allows to share auth credentials such as cookies. AllowCredentials bool + // A list of allowed origins. Wild cards and FQDNs are supported. + AllowOrigins []string // A list of allowed HTTP methods. AllowMethods []string // A list of allowed HTTP headers. @@ -143,7 +145,7 @@ func (o *Options) PreflightHeader(origin, rMethod, rHeaders string) (headers map rHeader = strings.TrimSpace(rHeader) lookupLoop: for _, allowedHeader := range o.AllowHeaders { - if strings.ToLower(rHeader) == strings.ToLower(allowedHeader) { + if strings.EqualFold(rHeader, allowedHeader) { allowed = append(allowed, rHeader) break lookupLoop } @@ -187,7 +189,7 @@ func (o *Options) IsOriginAllowed(origin string) (allowed bool) { } // Allow enables CORS for requests those match the provided options. -func Allow(opts *Options) beego.FilterFunc { +func Allow(opts *Options) web.FilterFunc { // Allow default headers if nothing is specified. if len(opts.AllowHeaders) == 0 { opts.AllowHeaders = defaultAllowHeaders diff --git a/plugins/cors/cors_test.go b/server/web/filter/cors/cors_test.go similarity index 88% rename from plugins/cors/cors_test.go rename to server/web/filter/cors/cors_test.go index 3403914353..e907a20214 100644 --- a/plugins/cors/cors_test.go +++ b/server/web/filter/cors/cors_test.go @@ -21,8 +21,8 @@ import ( "testing" "time" - "github.com/astaxie/beego" - "github.com/astaxie/beego/context" + "github.com/beego/beego/v2/server/web" + "github.com/beego/beego/v2/server/web/context" ) // HTTPHeaderGuardRecorder is httptest.ResponseRecorder with own http.Header @@ -55,8 +55,8 @@ func (gr *HTTPHeaderGuardRecorder) Header() http.Header { func Test_AllowAll(t *testing.T) { recorder := httptest.NewRecorder() - handler := beego.NewControllerRegister() - handler.InsertFilter("*", beego.BeforeRouter, Allow(&Options{ + handler := web.NewControllerRegister() + handler.InsertFilter("*", web.BeforeRouter, Allow(&Options{ AllowAllOrigins: true, })) handler.Any("/foo", func(ctx *context.Context) { @@ -72,8 +72,8 @@ func Test_AllowAll(t *testing.T) { func Test_AllowRegexMatch(t *testing.T) { recorder := httptest.NewRecorder() - handler := beego.NewControllerRegister() - handler.InsertFilter("*", beego.BeforeRouter, Allow(&Options{ + handler := web.NewControllerRegister() + handler.InsertFilter("*", web.BeforeRouter, Allow(&Options{ AllowOrigins: []string{"https://aaa.com", "https://*.foo.com"}, })) handler.Any("/foo", func(ctx *context.Context) { @@ -92,8 +92,8 @@ func Test_AllowRegexMatch(t *testing.T) { func Test_AllowRegexNoMatch(t *testing.T) { recorder := httptest.NewRecorder() - handler := beego.NewControllerRegister() - handler.InsertFilter("*", beego.BeforeRouter, Allow(&Options{ + handler := web.NewControllerRegister() + handler.InsertFilter("*", web.BeforeRouter, Allow(&Options{ AllowOrigins: []string{"https://*.foo.com"}, })) handler.Any("/foo", func(ctx *context.Context) { @@ -112,8 +112,8 @@ func Test_AllowRegexNoMatch(t *testing.T) { func Test_OtherHeaders(t *testing.T) { recorder := httptest.NewRecorder() - handler := beego.NewControllerRegister() - handler.InsertFilter("*", beego.BeforeRouter, Allow(&Options{ + handler := web.NewControllerRegister() + handler.InsertFilter("*", web.BeforeRouter, Allow(&Options{ AllowAllOrigins: true, AllowCredentials: true, AllowMethods: []string{"PATCH", "GET"}, @@ -156,8 +156,8 @@ func Test_OtherHeaders(t *testing.T) { func Test_DefaultAllowHeaders(t *testing.T) { recorder := httptest.NewRecorder() - handler := beego.NewControllerRegister() - handler.InsertFilter("*", beego.BeforeRouter, Allow(&Options{ + handler := web.NewControllerRegister() + handler.InsertFilter("*", web.BeforeRouter, Allow(&Options{ AllowAllOrigins: true, })) handler.Any("/foo", func(ctx *context.Context) { @@ -175,8 +175,8 @@ func Test_DefaultAllowHeaders(t *testing.T) { func Test_Preflight(t *testing.T) { recorder := NewRecorder() - handler := beego.NewControllerRegister() - handler.InsertFilter("*", beego.BeforeRouter, Allow(&Options{ + handler := web.NewControllerRegister() + handler.InsertFilter("*", web.BeforeRouter, Allow(&Options{ AllowAllOrigins: true, AllowMethods: []string{"PUT", "PATCH"}, AllowHeaders: []string{"Origin", "X-whatever", "X-CaseSensitive"}, @@ -219,8 +219,8 @@ func Test_Preflight(t *testing.T) { func Benchmark_WithoutCORS(b *testing.B) { recorder := httptest.NewRecorder() - handler := beego.NewControllerRegister() - beego.BConfig.RunMode = beego.PROD + handler := web.NewControllerRegister() + web.BConfig.RunMode = web.PROD handler.Any("/foo", func(ctx *context.Context) { ctx.Output.SetStatus(500) }) @@ -233,9 +233,9 @@ func Benchmark_WithoutCORS(b *testing.B) { func Benchmark_WithCORS(b *testing.B) { recorder := httptest.NewRecorder() - handler := beego.NewControllerRegister() - beego.BConfig.RunMode = beego.PROD - handler.InsertFilter("*", beego.BeforeRouter, Allow(&Options{ + handler := web.NewControllerRegister() + web.BConfig.RunMode = web.PROD + handler.InsertFilter("*", web.BeforeRouter, Allow(&Options{ AllowAllOrigins: true, AllowCredentials: true, AllowMethods: []string{"PATCH", "GET"}, diff --git a/server/web/filter/opentracing/filter.go b/server/web/filter/opentracing/filter.go new file mode 100644 index 0000000000..552c7a4cb5 --- /dev/null +++ b/server/web/filter/opentracing/filter.go @@ -0,0 +1,87 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package opentracing + +import ( + "context" + + opentracingKit "github.com/go-kit/kit/tracing/opentracing" + logKit "github.com/go-kit/log" + "github.com/opentracing/opentracing-go" + + "github.com/beego/beego/v2/server/web" + beegoCtx "github.com/beego/beego/v2/server/web/context" +) + +// FilterChainBuilder provides an extension point that we can support more configurations if necessary +type FilterChainBuilder struct { + // CustomSpanFunc makes users to custom the span. + CustomSpanFunc func(span opentracing.Span, ctx *beegoCtx.Context) +} + +func (builder *FilterChainBuilder) FilterChain(next web.FilterFunc) web.FilterFunc { + return func(ctx *beegoCtx.Context) { + var ( + spanCtx context.Context + span opentracing.Span + ) + operationName := builder.operationName(ctx) + + if preSpan := opentracing.SpanFromContext(ctx.Request.Context()); preSpan == nil { + inject := opentracingKit.HTTPToContext(opentracing.GlobalTracer(), operationName, logKit.NewNopLogger()) + spanCtx = inject(ctx.Request.Context(), ctx.Request) + span = opentracing.SpanFromContext(spanCtx) + } else { + span, spanCtx = opentracing.StartSpanFromContext(ctx.Request.Context(), operationName) + } + + defer span.Finish() + + newReq := ctx.Request.Clone(spanCtx) + ctx.Reset(ctx.ResponseWriter.ResponseWriter, newReq) + + next(ctx) + // if you think we need to do more things, feel free to create an issue to tell us + span.SetTag("http.status_code", ctx.ResponseWriter.Status) + span.SetTag("http.method", ctx.Input.Method()) + span.SetTag("peer.hostname", ctx.Request.Host) + span.SetTag("http.url", ctx.Request.URL.String()) + span.SetTag("http.scheme", ctx.Request.URL.Scheme) + span.SetTag("span.kind", "server") + span.SetTag("component", "beego") + if ctx.Output.IsServerError() || ctx.Output.IsClientError() { + span.SetTag("error", true) + } + span.SetTag("peer.address", ctx.Request.RemoteAddr) + span.SetTag("http.proto", ctx.Request.Proto) + + span.SetTag("beego.route", ctx.Input.GetData("RouterPattern")) + + if builder.CustomSpanFunc != nil { + builder.CustomSpanFunc(span, ctx) + } + } +} + +func (builder *FilterChainBuilder) operationName(ctx *beegoCtx.Context) string { + operationName := ctx.Input.URL() + // it means that there is not any span, so we create a span as the root span. + // TODO, if we support multiple servers, this need to be changed + route, found := web.BeeApp.Handlers.FindRouter(ctx) + if found { + operationName = ctx.Input.Method() + "#" + route.GetPattern() + } + return operationName +} diff --git a/server/web/filter/opentracing/filter_test.go b/server/web/filter/opentracing/filter_test.go new file mode 100644 index 0000000000..d92c98a144 --- /dev/null +++ b/server/web/filter/opentracing/filter_test.go @@ -0,0 +1,47 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package opentracing + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/opentracing/opentracing-go" + "github.com/stretchr/testify/assert" + + "github.com/beego/beego/v2/server/web/context" +) + +func TestFilterChainBuilder_FilterChain(t *testing.T) { + builder := &FilterChainBuilder{ + CustomSpanFunc: func(span opentracing.Span, ctx *context.Context) { + span.SetTag("aa", "bbb") + }, + } + + ctx := context.NewContext() + r, _ := http.NewRequest("GET", "/prometheus/user", nil) + w := httptest.NewRecorder() + ctx.Reset(w, r) + ctx.Input.SetData("RouterPattern", "my-route") + + filterFunc := builder.FilterChain(func(ctx *context.Context) { + ctx.Input.SetData("opentracing", true) + }) + + filterFunc(ctx) + assert.True(t, ctx.Input.GetData("opentracing").(bool)) +} diff --git a/server/web/filter/prometheus/filter.go b/server/web/filter/prometheus/filter.go new file mode 100644 index 0000000000..2a26e2f118 --- /dev/null +++ b/server/web/filter/prometheus/filter.go @@ -0,0 +1,107 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package prometheus + +import ( + "strconv" + "strings" + "sync" + "time" + + "github.com/prometheus/client_golang/prometheus" + + "github.com/beego/beego/v2" + "github.com/beego/beego/v2/core/logs" + "github.com/beego/beego/v2/server/web" + "github.com/beego/beego/v2/server/web/context" +) + +const unknownRouterPattern = "UnknownRouterPattern" + +// FilterChainBuilder is an extension point, +// when we want to support some configuration, +// please use this structure +type FilterChainBuilder struct{} + +var ( + summaryVec prometheus.ObserverVec + initSummaryVec sync.Once +) + +// FilterChain returns a FilterFunc. The filter will records some metrics +func (builder *FilterChainBuilder) FilterChain(next web.FilterFunc) web.FilterFunc { + initSummaryVec.Do(func() { + summaryVec = builder.buildVec() + err := prometheus.Register(summaryVec) + if _, ok := err.(*prometheus.AlreadyRegisteredError); err != nil && !ok { + logs.Error("web module register prometheus vector failed, %+v", err) + } + registerBuildInfo() + }) + + return func(ctx *context.Context) { + startTime := time.Now() + next(ctx) + endTime := time.Now() + go report(endTime.Sub(startTime), ctx, summaryVec) + } +} + +func (builder *FilterChainBuilder) buildVec() *prometheus.SummaryVec { + summaryVec := prometheus.NewSummaryVec(prometheus.SummaryOpts{ + Name: "beego", + Subsystem: "http_request", + ConstLabels: map[string]string{ + "server": web.BConfig.ServerName, + "env": web.BConfig.RunMode, + "appname": web.BConfig.AppName, + }, + Help: "The statics info for http request", + }, []string{"pattern", "method", "status"}) + return summaryVec +} + +func registerBuildInfo() { + buildInfo := prometheus.NewGaugeVec(prometheus.GaugeOpts{ + Name: "beego", + Subsystem: "build_info", + Help: "The building information", + ConstLabels: map[string]string{ + "appname": web.BConfig.AppName, + "build_version": beego.BuildVersion, + "build_revision": beego.BuildGitRevision, + "build_status": beego.BuildStatus, + "build_tag": beego.BuildTag, + "build_time": strings.Replace(beego.BuildTime, "--", " ", 1), + "go_version": beego.GoVersion, + "git_branch": beego.GitBranch, + "start_time": time.Now().Format("2006-01-02 15:04:05"), + }, + }, []string{}) + + _ = prometheus.Register(buildInfo) + buildInfo.WithLabelValues().Set(1) +} + +func report(dur time.Duration, ctx *context.Context, vec prometheus.ObserverVec) { + status := ctx.Output.Status + ptnItf := ctx.Input.GetData("RouterPattern") + ptn := unknownRouterPattern + if ptnItf != nil { + ptn = ptnItf.(string) + } + ms := dur / time.Millisecond + vec.WithLabelValues(ptn, ctx.Input.Method(), strconv.Itoa(status)).Observe(float64(ms)) +} diff --git a/server/web/filter/prometheus/filter_test.go b/server/web/filter/prometheus/filter_test.go new file mode 100644 index 0000000000..46ba3caf4d --- /dev/null +++ b/server/web/filter/prometheus/filter_test.go @@ -0,0 +1,55 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package prometheus + +import ( + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "github.com/beego/beego/v2/server/web/context" +) + +func TestFilterChain(t *testing.T) { + filter := (&FilterChainBuilder{}).FilterChain(func(ctx *context.Context) { + // do nothing + ctx.Input.SetData("invocation", true) + }) + + ctx := context.NewContext() + r, _ := http.NewRequest("GET", "/prometheus/user", nil) + w := httptest.NewRecorder() + ctx.Reset(w, r) + ctx.Input.SetData("RouterPattern", "my-route") + filter(ctx) + assert.True(t, ctx.Input.GetData("invocation").(bool)) + time.Sleep(1 * time.Second) +} + +func TestFilterChainBuilder_report(t *testing.T) { + ctx := context.NewContext() + r, _ := http.NewRequest("GET", "/prometheus/user", nil) + w := httptest.NewRecorder() + ctx.Reset(w, r) + fb := &FilterChainBuilder{} + // without router info + report(time.Second, ctx, fb.buildVec()) + + ctx.Input.SetData("RouterPattern", "my-route") + report(time.Second, ctx, fb.buildVec()) +} diff --git a/server/web/filter/ratelimit/bucket.go b/server/web/filter/ratelimit/bucket.go new file mode 100644 index 0000000000..67a3907e1f --- /dev/null +++ b/server/web/filter/ratelimit/bucket.go @@ -0,0 +1,14 @@ +package ratelimit + +import "time" + +// bucket is an interface store ratelimit info +type bucket interface { + take(amount uint) bool + getCapacity() uint + getRemaining() uint + getRate() time.Duration +} + +// bucketOption is constructor option +type bucketOption func(bucket) diff --git a/server/web/filter/ratelimit/limiter.go b/server/web/filter/ratelimit/limiter.go new file mode 100644 index 0000000000..17ef46e3f7 --- /dev/null +++ b/server/web/filter/ratelimit/limiter.go @@ -0,0 +1,163 @@ +// Copyright 2020 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ratelimit + +import ( + "sync" + "time" + + "github.com/beego/beego/v2/server/web" + "github.com/beego/beego/v2/server/web/context" +) + +// limiterOption is constructor option +type limiterOption func(l *limiter) + +type limiter struct { + sync.RWMutex + capacity uint + rate time.Duration + buckets map[string]bucket + bucketFactory func(opts ...bucketOption) bucket + sessionKey func(ctx *context.Context) string + resp RejectionResponse +} + +// RejectionResponse stores response information +// for the request rejected by limiter +type RejectionResponse struct { + code int + body string +} + +const perRequestConsumedAmount = 1 + +var defaultRejectionResponse = RejectionResponse{ + code: 429, + body: "too many requests", +} + +// NewLimiter return FilterFunc, the limiter enables rate limit +// according to the configuration. +func NewLimiter(opts ...limiterOption) web.FilterFunc { + l := &limiter{ + buckets: make(map[string]bucket), + sessionKey: defaultSessionKey, + rate: time.Millisecond * 10, + capacity: 100, + bucketFactory: newTokenBucket, + resp: defaultRejectionResponse, + } + for _, o := range opts { + o(l) + } + + return func(ctx *context.Context) { + if !l.take(perRequestConsumedAmount, ctx) { + ctx.ResponseWriter.WriteHeader(l.resp.code) + ctx.WriteString(l.resp.body) + } + } +} + +// WithSessionKey return limiterOption. WithSessionKey config func +// which defines the request characteristic against the limit is applied +func WithSessionKey(f func(ctx *context.Context) string) limiterOption { + return func(l *limiter) { + l.sessionKey = f + } +} + +// WithRate return limiterOption. WithRate config how long it takes to +// generate a token. +func WithRate(r time.Duration) limiterOption { + return func(l *limiter) { + l.rate = r + } +} + +// WithCapacity return limiterOption. WithCapacity config the capacity size. +// The bucket with a capacity of n has n tokens after initialization. The capacity +// defines how many requests a client can make in excess of the rate. +func WithCapacity(c uint) limiterOption { + return func(l *limiter) { + l.capacity = c + } +} + +// WithBucketFactory return limiterOption. WithBucketFactory customize the +// implementation of Bucket. +func WithBucketFactory(f func(opts ...bucketOption) bucket) limiterOption { + return func(l *limiter) { + l.bucketFactory = f + } +} + +// WithRejectionResponse return limiterOption. WithRejectionResponse +// customize the response for the request rejected by the limiter. +func WithRejectionResponse(resp RejectionResponse) limiterOption { + return func(l *limiter) { + l.resp = resp + } +} + +func (l *limiter) take(amount uint, ctx *context.Context) bool { + bucket := l.getBucket(ctx) + if bucket == nil { + return true + } + return bucket.take(amount) +} + +func (l *limiter) getBucket(ctx *context.Context) bucket { + key := l.sessionKey(ctx) + l.RLock() + b, ok := l.buckets[key] + l.RUnlock() + if !ok { + b = l.createBucket(key) + } + + return b +} + +func (l *limiter) createBucket(key string) bucket { + l.Lock() + defer l.Unlock() + // double check avoid overwriting + b, ok := l.buckets[key] + if ok { + return b + } + b = l.bucketFactory(withCapacity(l.capacity), withRate(l.rate)) + l.buckets[key] = b + return b +} + +func defaultSessionKey(ctx *context.Context) string { + return "BEEGO_ALL" +} + +func RemoteIPSessionKey(ctx *context.Context) string { + r := ctx.Request + IPAddress := r.Header.Get("X-Real-Ip") + if IPAddress == "" { + IPAddress = r.Header.Get("X-Forwarded-For") + } + if IPAddress == "" { + IPAddress = r.RemoteAddr + } + return IPAddress +} diff --git a/server/web/filter/ratelimit/limiter_test.go b/server/web/filter/ratelimit/limiter_test.go new file mode 100644 index 0000000000..cafede5ead --- /dev/null +++ b/server/web/filter/ratelimit/limiter_test.go @@ -0,0 +1,76 @@ +package ratelimit + +import ( + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/beego/beego/v2/server/web" + "github.com/beego/beego/v2/server/web/context" +) + +func testRequest(t *testing.T, handler *web.ControllerRegister, requestIP, method, path string, code int) { + r, _ := http.NewRequest(method, path, nil) + r.Header.Set("X-Real-Ip", requestIP) + w := httptest.NewRecorder() + handler.ServeHTTP(w, r) + if w.Code != code { + t.Errorf("%s, %s, %s: %d, supposed to be %d", requestIP, method, path, w.Code, code) + } +} + +func TestLimiter(t *testing.T) { + handler := web.NewControllerRegister() + err := handler.InsertFilter("/foo/*", web.BeforeRouter, NewLimiter(WithRate(1*time.Millisecond), WithCapacity(1), WithSessionKey(RemoteIPSessionKey))) + if err != nil { + t.Error(err) + } + handler.Any("*", func(ctx *context.Context) { + ctx.Output.SetStatus(200) + }) + + route := "/foo/1" + ip := "127.0.0.1" + testRequest(t, handler, ip, "GET", route, 200) + testRequest(t, handler, ip, "GET", route, 429) + testRequest(t, handler, "127.0.0.2", "GET", route, 200) + time.Sleep(1 * time.Millisecond) + testRequest(t, handler, ip, "GET", route, 200) +} + +func BenchmarkWithoutLimiter(b *testing.B) { + recorder := httptest.NewRecorder() + handler := web.NewControllerRegister() + web.BConfig.RunMode = web.PROD + handler.Any("/foo", func(ctx *context.Context) { + ctx.Output.SetStatus(500) + }) + b.ResetTimer() + r, _ := http.NewRequest("PUT", "/foo", nil) + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + handler.ServeHTTP(recorder, r) + } + }) +} + +func BenchmarkWithLimiter(b *testing.B) { + recorder := httptest.NewRecorder() + handler := web.NewControllerRegister() + web.BConfig.RunMode = web.PROD + err := handler.InsertFilter("*", web.BeforeRouter, NewLimiter(WithRate(1*time.Millisecond), WithCapacity(100))) + if err != nil { + b.Error(err) + } + handler.Any("/foo", func(ctx *context.Context) { + ctx.Output.SetStatus(500) + }) + b.ResetTimer() + r, _ := http.NewRequest("PUT", "/foo", nil) + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + handler.ServeHTTP(recorder, r) + } + }) +} diff --git a/server/web/filter/ratelimit/token_bucket.go b/server/web/filter/ratelimit/token_bucket.go new file mode 100644 index 0000000000..da7bb7fcda --- /dev/null +++ b/server/web/filter/ratelimit/token_bucket.go @@ -0,0 +1,76 @@ +package ratelimit + +import ( + "sync" + "time" +) + +type tokenBucket struct { + sync.RWMutex + remaining uint + capacity uint + lastCheckAt time.Time + rate time.Duration +} + +// newTokenBucket return an bucket that implements token bucket +func newTokenBucket(opts ...bucketOption) bucket { + b := &tokenBucket{lastCheckAt: time.Now()} + for _, o := range opts { + o(b) + } + return b +} + +func withCapacity(capacity uint) bucketOption { + return func(b bucket) { + bucket := b.(*tokenBucket) + bucket.capacity = capacity + bucket.remaining = capacity + } +} + +func withRate(rate time.Duration) bucketOption { + return func(b bucket) { + bucket := b.(*tokenBucket) + bucket.rate = rate + } +} + +func (b *tokenBucket) getRemaining() uint { + b.RLock() + defer b.RUnlock() + return b.remaining +} + +func (b *tokenBucket) getRate() time.Duration { + b.RLock() + defer b.RUnlock() + return b.rate +} + +func (b *tokenBucket) getCapacity() uint { + b.RLock() + defer b.RUnlock() + return b.capacity +} + +func (b *tokenBucket) take(amount uint) bool { + if b.rate <= 0 { + return true + } + b.Lock() + defer b.Unlock() + now := time.Now() + times := uint(now.Sub(b.lastCheckAt) / b.rate) + b.lastCheckAt = b.lastCheckAt.Add(time.Duration(times) * b.rate) + b.remaining += times + if b.remaining < amount { + return false + } + b.remaining -= amount + if b.remaining > b.capacity { + b.remaining = b.capacity + } + return true +} diff --git a/server/web/filter/ratelimit/token_bucket_test.go b/server/web/filter/ratelimit/token_bucket_test.go new file mode 100644 index 0000000000..91088eb6da --- /dev/null +++ b/server/web/filter/ratelimit/token_bucket_test.go @@ -0,0 +1,32 @@ +package ratelimit + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestGetRate(t *testing.T) { + b := newTokenBucket(withRate(1 * time.Second)).(*tokenBucket) + assert.Equal(t, b.getRate(), 1*time.Second) +} + +func TestGetRemainingAndCapacity(t *testing.T) { + b := newTokenBucket(withCapacity(10)) + assert.Equal(t, b.getRemaining(), uint(10)) + assert.Equal(t, b.getCapacity(), uint(10)) +} + +func TestTake(t *testing.T) { + b := newTokenBucket(withCapacity(10), withRate(10*time.Millisecond)).(*tokenBucket) + for i := 0; i < 10; i++ { + assert.True(t, b.take(1)) + } + assert.False(t, b.take(1)) + assert.Equal(t, b.getRemaining(), uint(0)) + b = newTokenBucket(withCapacity(1), withRate(1*time.Millisecond)).(*tokenBucket) + assert.True(t, b.take(1)) + time.Sleep(2 * time.Millisecond) + assert.True(t, b.take(1)) +} diff --git a/server/web/filter/session/filter.go b/server/web/filter/session/filter.go new file mode 100644 index 0000000000..b3bb99c54d --- /dev/null +++ b/server/web/filter/session/filter.go @@ -0,0 +1,37 @@ +package session + +import ( + "context" + + "github.com/beego/beego/v2/core/logs" + "github.com/beego/beego/v2/server/web" + webContext "github.com/beego/beego/v2/server/web/context" + "github.com/beego/beego/v2/server/web/session" +) + +// Session maintain session for web service +// Session new a session storage and store it into webContext.Context +// experimental feature, we may change this in the future +func Session(providerType session.ProviderType, options ...session.ManagerConfigOpt) web.FilterChain { + sessionConfig := session.NewManagerConfig(options...) + sessionManager, _ := session.NewManager(string(providerType), sessionConfig) + go sessionManager.GC() + + return func(next web.FilterFunc) web.FilterFunc { + return func(ctx *webContext.Context) { + if ctx.Input.CruSession != nil { + return + } + + if sess, err := sessionManager.SessionStart(ctx.ResponseWriter, ctx.Request); err != nil { + logs.Error(`init session error:%s`, err.Error()) + } else { + // release session at the end of request + defer sess.SessionRelease(context.Background(), ctx.ResponseWriter) + ctx.Input.CruSession = sess + } + + next(ctx) + } + } +} diff --git a/server/web/filter/session/filter_test.go b/server/web/filter/session/filter_test.go new file mode 100644 index 0000000000..6b41396269 --- /dev/null +++ b/server/web/filter/session/filter_test.go @@ -0,0 +1,88 @@ +package session + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/google/uuid" + + "github.com/beego/beego/v2/server/web" + webContext "github.com/beego/beego/v2/server/web/context" + "github.com/beego/beego/v2/server/web/session" +) + +func testRequest(t *testing.T, handler *web.ControllerRegister, path string, method string, code int) { + r, _ := http.NewRequest(method, path, nil) + w := httptest.NewRecorder() + handler.ServeHTTP(w, r) + + if w.Code != code { + t.Errorf("%s, %s: %d, supposed to be %d", path, method, w.Code, code) + } +} + +func TestSession(t *testing.T) { + storeKey := uuid.New().String() + handler := web.NewControllerRegister() + handler.InsertFilterChain( + "*", + Session( + session.ProviderMemory, + session.CfgCookieName(`go_session_id`), + session.CfgSetCookie(true), + session.CfgGcLifeTime(3600), + session.CfgMaxLifeTime(3600), + session.CfgSecure(false), + session.CfgCookieLifeTime(3600), + ), + ) + handler.InsertFilterChain( + "*", + func(next web.FilterFunc) web.FilterFunc { + return func(ctx *webContext.Context) { + if store := ctx.Input.GetData(storeKey); store == nil { + t.Error(`store should not be nil`) + } + next(ctx) + } + }, + ) + handler.Any("*", func(ctx *webContext.Context) { + ctx.Output.SetStatus(200) + }) + + testRequest(t, handler, "/dataset1/resource1", "GET", 200) +} + +func TestSession1(t *testing.T) { + handler := web.NewControllerRegister() + handler.InsertFilterChain( + "*", + Session( + session.ProviderMemory, + session.CfgCookieName(`go_session_id`), + session.CfgSetCookie(true), + session.CfgGcLifeTime(3600), + session.CfgMaxLifeTime(3600), + session.CfgSecure(false), + session.CfgCookieLifeTime(3600), + ), + ) + handler.InsertFilterChain( + "*", + func(next web.FilterFunc) web.FilterFunc { + return func(ctx *webContext.Context) { + if store, err := ctx.Session(); store == nil || err != nil { + t.Error(`store should not be nil`) + } + next(ctx) + } + }, + ) + handler.Any("*", func(ctx *webContext.Context) { + ctx.Output.SetStatus(200) + }) + + testRequest(t, handler, "/dataset1/resource1", "GET", 200) +} diff --git a/server/web/filter_chain_test.go b/server/web/filter_chain_test.go new file mode 100644 index 0000000000..d3faf51659 --- /dev/null +++ b/server/web/filter_chain_test.go @@ -0,0 +1,136 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package web + +import ( + "fmt" + "net/http" + "net/http/httptest" + "strconv" + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "github.com/beego/beego/v2/server/web/context" +) + +func TestControllerRegisterInsertFilterChain(t *testing.T) { + InsertFilterChain("/*", func(next FilterFunc) FilterFunc { + return func(ctx *context.Context) { + ctx.Output.Header("filter", "filter-chain") + next(ctx) + } + }) + + ns := NewNamespace("/chain") + + ns.Get("/*", func(ctx *context.Context) { + _ = ctx.Output.Body([]byte("hello")) + }) + + r, _ := http.NewRequest("GET", "/chain/user", nil) + w := httptest.NewRecorder() + + BeeApp.Handlers.Init() + BeeApp.Handlers.ServeHTTP(w, r) + + assert.Equal(t, "filter-chain", w.Header().Get("filter")) +} + +func TestControllerRegister_InsertFilterChain_Order(t *testing.T) { + InsertFilterChain("/abc", func(next FilterFunc) FilterFunc { + return func(ctx *context.Context) { + ctx.Output.Header("first", fmt.Sprintf("%d", time.Now().UnixNano())) + time.Sleep(time.Millisecond * 10) + next(ctx) + } + }) + + InsertFilterChain("/abc", func(next FilterFunc) FilterFunc { + return func(ctx *context.Context) { + ctx.Output.Header("second", fmt.Sprintf("%d", time.Now().UnixNano())) + time.Sleep(time.Millisecond * 10) + next(ctx) + } + }) + + r, _ := http.NewRequest("GET", "/abc", nil) + w := httptest.NewRecorder() + + BeeApp.Handlers.Init() + BeeApp.Handlers.ServeHTTP(w, r) + first := w.Header().Get("first") + second := w.Header().Get("second") + + ft, _ := strconv.ParseInt(first, 10, 64) + st, _ := strconv.ParseInt(second, 10, 64) + + assert.True(t, st > ft) +} + +func TestFilterChainRouter(t *testing.T) { + app := NewHttpSever() + const filterNonMatch = "filter-chain-non-match" + app.InsertFilterChain("/app/nonMatch/before/*", func(next FilterFunc) FilterFunc { + return func(ctx *context.Context) { + ctx.Output.Header("filter", filterNonMatch) + next(ctx) + } + }) + + const filterAll = "filter-chain-all" + app.InsertFilterChain("/*", func(next FilterFunc) FilterFunc { + return func(ctx *context.Context) { + ctx.Output.Header("filter", filterAll) + next(ctx) + } + }) + + app.InsertFilterChain("/app/nonMatch/after/*", func(next FilterFunc) FilterFunc { + return func(ctx *context.Context) { + ctx.Output.Header("filter", filterNonMatch) + next(ctx) + } + }) + + app.InsertFilterChain("/app/match/*", func(next FilterFunc) FilterFunc { + return func(ctx *context.Context) { + ctx.Output.Header("match", "yes") + next(ctx) + } + }) + + app.Handlers.Init() + + r, _ := http.NewRequest("GET", "/app/match", nil) + w := httptest.NewRecorder() + + app.Handlers.ServeHTTP(w, r) + assert.Equal(t, filterAll, w.Header().Get("filter")) + assert.Equal(t, "yes", w.Header().Get("match")) + + r, _ = http.NewRequest("GET", "/app/match1", nil) + w = httptest.NewRecorder() + app.Handlers.ServeHTTP(w, r) + assert.Equal(t, filterAll, w.Header().Get("filter")) + assert.NotEqual(t, "yes", w.Header().Get("match")) + + r, _ = http.NewRequest("GET", "/app/nonMatch", nil) + w = httptest.NewRecorder() + app.Handlers.ServeHTTP(w, r) + assert.Equal(t, filterAll, w.Header().Get("filter")) + assert.NotEqual(t, "yes", w.Header().Get("match")) +} diff --git a/filter_test.go b/server/web/filter_test.go similarity index 97% rename from filter_test.go rename to server/web/filter_test.go index 4ca4d2b848..8765243cbb 100644 --- a/filter_test.go +++ b/server/web/filter_test.go @@ -12,14 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. -package beego +package web import ( "net/http" "net/http/httptest" "testing" - "github.com/astaxie/beego/context" + "github.com/beego/beego/v2/server/web/context" ) var FilterUser = func(ctx *context.Context) { diff --git a/flash.go b/server/web/flash.go similarity index 98% rename from flash.go rename to server/web/flash.go index a6485a17e2..4b6567ac03 100644 --- a/flash.go +++ b/server/web/flash.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package beego +package web import ( "fmt" @@ -102,7 +102,7 @@ func ReadFromRequest(c *Controller) *FlashData { } } } - //read one time then delete it + // read one time then delete it c.Ctx.SetCookie(BConfig.WebConfig.FlashName, "", -1, "/") } c.Data["flash"] = flash.Data diff --git a/flash_test.go b/server/web/flash_test.go similarity index 92% rename from flash_test.go rename to server/web/flash_test.go index d5e9608dc9..3e20c8fb12 100644 --- a/flash_test.go +++ b/server/web/flash_test.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package beego +package web import ( "net/http" @@ -40,7 +40,7 @@ func TestFlashHeader(t *testing.T) { // setup the handler handler := NewControllerRegister() - handler.Add("/", &TestFlashController{}, "get:TestWriteFlash") + handler.Add("/", &TestFlashController{}, WithRouterMethods(&TestFlashController{}, "get:TestWriteFlash")) handler.ServeHTTP(w, r) // get the Set-Cookie value diff --git a/fs.go b/server/web/fs.go similarity index 97% rename from fs.go rename to server/web/fs.go index 41cc6f6e0d..4d65630a32 100644 --- a/fs.go +++ b/server/web/fs.go @@ -1,4 +1,4 @@ -package beego +package web import ( "net/http" @@ -6,8 +6,7 @@ import ( "path/filepath" ) -type FileSystem struct { -} +type FileSystem struct{} func (d FileSystem) Open(name string) (http.File, error) { return os.Open(name) @@ -17,7 +16,6 @@ func (d FileSystem) Open(name string) (http.File, error) { // directory in the tree, including root. All errors that arise visiting files // and directories are filtered by walkFn. func Walk(fs http.FileSystem, root string, walkFn filepath.WalkFunc) error { - f, err := fs.Open(root) if err != nil { return err diff --git a/grace/grace.go b/server/web/grace/grace.go similarity index 80% rename from grace/grace.go rename to server/web/grace/grace.go index fb0cb7bb69..cefc890bb5 100644 --- a/grace/grace.go +++ b/server/web/grace/grace.go @@ -18,28 +18,30 @@ // Usage: // // import( -// "log" -// "net/http" -// "os" // -// "github.com/astaxie/beego/grace" +// "log" +// "net/http" +// "os" +// +// "github.com/beego/beego/v2/server/web/grace" +// // ) // -// func handler(w http.ResponseWriter, r *http.Request) { -// w.Write([]byte("WORLD!")) -// } +// func handler(w http.ResponseWriter, r *http.Request) { +// w.Write([]byte("WORLD!")) +// } // -// func main() { -// mux := http.NewServeMux() -// mux.HandleFunc("/hello", handler) +// func main() { +// mux := http.NewServeMux() +// mux.HandleFunc("/hello", handler) // -// err := grace.ListenAndServe("localhost:8080", mux) -// if err != nil { -// log.Println(err) -// } -// log.Println("Server on 8080 stopped") -// os.Exit(0) -// } +// err := grace.ListenAndServe("localhost:8080", mux) +// if err != nil { +// log.Println(err) +// } +// log.Println("Server on 8080 stopped") +// os.Exit(0) +// } package grace import ( @@ -105,8 +107,17 @@ func init() { } } +// ServerOption configures how we set up the connection. +type ServerOption func(*Server) + +func WithShutdownCallback(shutdownCallback func()) ServerOption { + return func(srv *Server) { + srv.shutdownCallbacks = append(srv.shutdownCallbacks, shutdownCallback) + } +} + // NewServer returns a new graceServer. -func NewServer(addr string, handler http.Handler) (srv *Server) { +func NewServer(addr string, handler http.Handler, opts ...ServerOption) (srv *Server) { regLock.Lock() defer regLock.Unlock() @@ -138,7 +149,7 @@ func NewServer(addr string, handler http.Handler) (srv *Server) { }, state: StateInit, Network: "tcp", - terminalChan: make(chan error), //no cache channel + terminalChan: make(chan error), // no cache channel } srv.Server = &http.Server{ Addr: addr, @@ -148,6 +159,10 @@ func NewServer(addr string, handler http.Handler) (srv *Server) { Handler: handler, } + for _, opt := range opts { + opt(srv) + } + runningServersOrder = append(runningServersOrder, addr) runningServers[addr] = srv return srv diff --git a/grace/server.go b/server/web/grace/server.go similarity index 76% rename from grace/server.go rename to server/web/grace/server.go index 1ce8bc7821..ec55e65d82 100644 --- a/grace/server.go +++ b/server/web/grace/server.go @@ -5,7 +5,6 @@ import ( "crypto/tls" "crypto/x509" "fmt" - "io/ioutil" "log" "net" "net/http" @@ -20,33 +19,47 @@ import ( // Server embedded http.Server type Server struct { *http.Server - ln net.Listener - SignalHooks map[int]map[os.Signal][]func() - sigChan chan os.Signal - isChild bool - state uint8 - Network string - terminalChan chan error + ln net.Listener + SignalHooks map[int]map[os.Signal][]func() + sigChan chan os.Signal + isChild bool + state uint8 + Network string + terminalChan chan error + shutdownCallbacks []func() } -// Serve accepts incoming connections on the Listener l, -// creating a new service goroutine for each. +// Serve accepts incoming connections on the Listener l +// and creates a new service goroutine for each. // The service goroutines read requests and then call srv.Handler to reply to them. func (srv *Server) Serve() (err error) { + return srv.internalServe(srv.ln) +} + +func (srv *Server) ServeWithListener(ln net.Listener) (err error) { + srv.ln = ln + go srv.handleSignals() + return srv.internalServe(ln) +} + +func (srv *Server) internalServe(ln net.Listener) (err error) { srv.state = StateRunning defer func() { srv.state = StateTerminate }() // When Shutdown is called, Serve, ListenAndServe, and ListenAndServeTLS // immediately return ErrServerClosed. Make sure the program doesn't exit // and waits instead for Shutdown to return. - if err = srv.Server.Serve(srv.ln); err != nil && err != http.ErrServerClosed { + if err = srv.Server.Serve(ln); err != nil && err != http.ErrServerClosed { log.Println(syscall.Getpid(), "Server.Serve() error:", err) return err } - log.Println(syscall.Getpid(), srv.ln.Addr(), "Listener closed.") + log.Println(syscall.Getpid(), ln.Addr(), "Listener closed.") // wait for Shutdown to return - return <-srv.terminalChan + if shutdownErr := <-srv.terminalChan; shutdownErr != nil { + return shutdownErr + } + return } // ListenAndServe listens on the TCP network address srv.Addr and then calls Serve @@ -92,6 +105,15 @@ func (srv *Server) ListenAndServe() (err error) { // // If srv.Addr is blank, ":https" is used. func (srv *Server) ListenAndServeTLS(certFile, keyFile string) (err error) { + ln, err := srv.ListenTLS(certFile, keyFile) + if err != nil { + return err + } + + return srv.ServeTLS(ln) +} + +func (srv *Server) ListenTLS(certFile string, keyFile string) (net.Listener, error) { addr := srv.Addr if addr == "" { addr = ":https" @@ -105,20 +127,35 @@ func (srv *Server) ListenAndServeTLS(certFile, keyFile string) (err error) { } srv.TLSConfig.Certificates = make([]tls.Certificate, 1) - srv.TLSConfig.Certificates[0], err = tls.LoadX509KeyPair(certFile, keyFile) + cert, err := tls.LoadX509KeyPair(certFile, keyFile) if err != nil { - return + return nil, err } + srv.TLSConfig.Certificates[0] = cert go srv.handleSignals() ln, err := srv.getListener(addr) if err != nil { log.Println(err) + return nil, err + } + tlsListener := tls.NewListener(tcpKeepAliveListener{ln.(*net.TCPListener)}, srv.TLSConfig) + return tlsListener, nil +} + +// ListenAndServeMutualTLS listens on the TCP network address srv.Addr and then calls +// Serve to handle requests on incoming mutual TLS connections. +func (srv *Server) ListenAndServeMutualTLS(certFile, keyFile, trustFile string) (err error) { + ln, err := srv.ListenMutualTLS(certFile, keyFile, trustFile) + if err != nil { return err } - srv.ln = tls.NewListener(tcpKeepAliveListener{ln.(*net.TCPListener)}, srv.TLSConfig) + return srv.ServeTLS(ln) +} + +func (srv *Server) ServeTLS(ln net.Listener) error { if srv.isChild { process, err := os.FindProcess(os.Getppid()) if err != nil { @@ -131,13 +168,11 @@ func (srv *Server) ListenAndServeTLS(certFile, keyFile string) (err error) { } } - log.Println(os.Getpid(), srv.Addr) - return srv.Serve() + go srv.handleSignals() + return srv.internalServe(ln) } -// ListenAndServeMutualTLS listens on the TCP network address srv.Addr and then calls -// Serve to handle requests on incoming mutual TLS connections. -func (srv *Server) ListenAndServeMutualTLS(certFile, keyFile, trustFile string) (err error) { +func (srv *Server) ListenMutualTLS(certFile string, keyFile string, trustFile string) (net.Listener, error) { addr := srv.Addr if addr == "" { addr = ":https" @@ -151,16 +186,17 @@ func (srv *Server) ListenAndServeMutualTLS(certFile, keyFile, trustFile string) } srv.TLSConfig.Certificates = make([]tls.Certificate, 1) - srv.TLSConfig.Certificates[0], err = tls.LoadX509KeyPair(certFile, keyFile) + cert, err := tls.LoadX509KeyPair(certFile, keyFile) if err != nil { - return + return nil, err } + srv.TLSConfig.Certificates[0] = cert srv.TLSConfig.ClientAuth = tls.RequireAndVerifyClientCert pool := x509.NewCertPool() - data, err := ioutil.ReadFile(trustFile) + data, err := os.ReadFile(trustFile) if err != nil { log.Println(err) - return err + return nil, err } pool.AppendCertsFromPEM(data) srv.TLSConfig.ClientCAs = pool @@ -170,24 +206,10 @@ func (srv *Server) ListenAndServeMutualTLS(certFile, keyFile, trustFile string) ln, err := srv.getListener(addr) if err != nil { log.Println(err) - return err + return nil, err } - srv.ln = tls.NewListener(tcpKeepAliveListener{ln.(*net.TCPListener)}, srv.TLSConfig) - - if srv.isChild { - process, err := os.FindProcess(os.Getppid()) - if err != nil { - log.Println(err) - return err - } - err = process.Kill() - if err != nil { - return err - } - } - - log.Println(os.Getpid(), srv.Addr) - return srv.Serve() + tlsListener := tls.NewListener(tcpKeepAliveListener{ln.(*net.TCPListener)}, srv.TLSConfig) + return tlsListener, nil } // getListener either opens a new socket to listen on, or takes the acceptor socket @@ -289,6 +311,9 @@ func (srv *Server) shutdown() { ctx, cancel = context.WithTimeout(context.Background(), DefaultTimeout) defer cancel() } + for _, shutdownCallback := range srv.shutdownCallbacks { + shutdownCallback() + } srv.terminalChan <- srv.Server.Shutdown(ctx) } @@ -300,8 +325,8 @@ func (srv *Server) fork() (err error) { } runningServersForked = true - var files = make([]*os.File, len(runningServers)) - var orderArgs = make([]string, len(runningServers)) + files := make([]*os.File, len(runningServers)) + orderArgs := make([]string, len(runningServers)) for _, srvPtr := range runningServers { f, _ := srvPtr.ln.(*net.TCPListener).File() files[socketPtrOffsetMap[srvPtr.Server.Addr]] = f @@ -313,15 +338,15 @@ func (srv *Server) fork() (err error) { var args []string if len(os.Args) > 1 { for _, arg := range os.Args[1:] { - if arg == "-graceful" { + if strings.TrimLeft(arg, "-") == "graceful" { break } args = append(args, arg) } } - args = append(args, "-graceful") + args = append(args, "--graceful") if len(runningServers) > 1 { - args = append(args, fmt.Sprintf(`-socketorder=%s`, strings.Join(orderArgs, ","))) + args = append(args, fmt.Sprintf(`--socketorder=%s`, strings.Join(orderArgs, ","))) log.Println(args) } cmd := exec.Command(path, args...) diff --git a/hooks.go b/server/web/hooks.go similarity index 85% rename from hooks.go rename to server/web/hooks.go index b8671d3530..68cc61a182 100644 --- a/hooks.go +++ b/server/web/hooks.go @@ -1,4 +1,4 @@ -package beego +package web import ( "encoding/json" @@ -6,9 +6,9 @@ import ( "net/http" "path/filepath" - "github.com/astaxie/beego/context" - "github.com/astaxie/beego/logs" - "github.com/astaxie/beego/session" + "github.com/beego/beego/v2/core/logs" + "github.com/beego/beego/v2/server/web/context" + "github.com/beego/beego/v2/server/web/session" ) // register MIME type with content type @@ -34,6 +34,7 @@ func registerDefaultErrorHandler() error { "504": gatewayTimeout, "417": invalidxsrf, "422": missingxsrf, + "413": payloadTooLarge, } for e, h := range m { if _, ok := ErrorMaps[e]; !ok { @@ -46,9 +47,9 @@ func registerDefaultErrorHandler() error { func registerSession() error { if BConfig.WebConfig.Session.SessionOn { var err error - sessionConfig := AppConfig.String("sessionConfig") + sessionConfig, err := AppConfig.String("sessionConfig") conf := new(session.ManagerConfig) - if sessionConfig == "" { + if sessionConfig == "" || err != nil { conf.CookieName = BConfig.WebConfig.Session.SessionName conf.EnableSetCookie = BConfig.WebConfig.Session.SessionAutoSetCookie conf.Gclifetime = BConfig.WebConfig.Session.SessionGCMaxLifetime @@ -60,6 +61,8 @@ func registerSession() error { conf.EnableSidInHTTPHeader = BConfig.WebConfig.Session.SessionEnableSidInHTTPHeader conf.SessionNameInHTTPHeader = BConfig.WebConfig.Session.SessionNameInHTTPHeader conf.EnableSidInURLQuery = BConfig.WebConfig.Session.SessionEnableSidInURLQuery + conf.CookieSameSite = BConfig.WebConfig.Session.SessionCookieSameSite + conf.SessionIDPrefix = BConfig.WebConfig.Session.SessionIDPrefix } else { if err = json.Unmarshal([]byte(sessionConfig), conf); err != nil { return err @@ -84,13 +87,6 @@ func registerTemplate() error { return nil } -func registerAdmin() error { - if BConfig.Listen.EnableAdmin { - go beeAdminApp.Run() - } - return nil -} - func registerGzip() error { if BConfig.EnableGzip { context.InitGzip( diff --git a/mime.go b/server/web/mime.go similarity index 99% rename from mime.go rename to server/web/mime.go index ca2878ab25..9393e9c7b9 100644 --- a/mime.go +++ b/server/web/mime.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package beego +package web var mimemaps = map[string]string{ ".3dm": "x-world/x-3dmf", diff --git a/server/web/mock/context.go b/server/web/mock/context.go new file mode 100644 index 0000000000..65df0af90b --- /dev/null +++ b/server/web/mock/context.go @@ -0,0 +1,29 @@ +// Copyright 2021 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mock + +import ( + "net/http" + "net/http/httptest" + + beegoCtx "github.com/beego/beego/v2/server/web/context" +) + +func NewMockContext(req *http.Request) (*beegoCtx.Context, *httptest.ResponseRecorder) { + ctx := beegoCtx.NewContext() + resp := httptest.NewRecorder() + ctx.Reset(resp, req) + return ctx, resp +} diff --git a/server/web/mock/context_test.go b/server/web/mock/context_test.go new file mode 100644 index 0000000000..0eb8665636 --- /dev/null +++ b/server/web/mock/context_test.go @@ -0,0 +1,66 @@ +// Copyright 2021 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mock + +import ( + "bytes" + "context" + "fmt" + "net/http" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/beego/beego/v2/server/web" +) + +type TestController struct { + web.Controller +} + +func TestMockContext(t *testing.T) { + req, err := http.NewRequest("GET", "https://localhost:8080/hello?name=tom", bytes.NewReader([]byte{})) + assert.Nil(t, err) + ctx, resp := NewMockContext(req) + ctrl := &TestController{ + Controller: web.Controller{ + Ctx: ctx, + }, + } + ctrl.HelloWorld() + result := resp.Body.String() + assert.Equal(t, "name=tom", result) +} + +// GET hello?name=XXX +func (c *TestController) HelloWorld() { + name := c.GetString("name") + c.Ctx.WriteString(fmt.Sprintf("name=%s", name)) +} + +func (c *TestController) HelloSession() { + err := c.SessionRegenerateID() + if err != nil { + c.Ctx.WriteString("error") + return + } + _ = c.SetSession("name", "Tom") + c.Ctx.WriteString("set") +} + +func (c *TestController) HelloSessionName() { + name := c.CruSession.Get(context.Background(), "name") + c.Ctx.WriteString(name.(string)) +} diff --git a/server/web/mock/session.go b/server/web/mock/session.go new file mode 100644 index 0000000000..4b154e3684 --- /dev/null +++ b/server/web/mock/session.go @@ -0,0 +1,128 @@ +// Copyright 2021 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mock + +import ( + "context" + "net/http" + + "github.com/google/uuid" + + "github.com/beego/beego/v2/server/web" + "github.com/beego/beego/v2/server/web/session" +) + +// NewSessionProvider create new SessionProvider +// and you could use it to mock data +// Parameter "name" is the real SessionProvider you used +func NewSessionProvider(name string) *SessionProvider { + sp := newSessionProvider() + session.Register(name, sp) + web.GlobalSessions, _ = session.NewManager(name, session.NewManagerConfig()) + return sp +} + +// SessionProvider will replace session provider with "mock" provider +type SessionProvider struct { + Store *SessionStore +} + +func newSessionProvider() *SessionProvider { + return &SessionProvider{ + Store: newSessionStore(), + } +} + +// SessionInit do nothing +func (s *SessionProvider) SessionInit(ctx context.Context, gclifetime int64, config string) error { + return nil +} + +// SessionRead return Store +func (s *SessionProvider) SessionRead(ctx context.Context, sid string) (session.Store, error) { + return s.Store, nil +} + +// SessionExist always return true +func (s *SessionProvider) SessionExist(ctx context.Context, sid string) (bool, error) { + return true, nil +} + +// SessionRegenerate create new Store +func (s *SessionProvider) SessionRegenerate(ctx context.Context, oldsid, sid string) (session.Store, error) { + s.Store = newSessionStore() + return s.Store, nil +} + +// SessionDestroy reset Store to nil +func (s *SessionProvider) SessionDestroy(ctx context.Context, sid string) error { + s.Store = nil + return nil +} + +// SessionAll return 0 +func (s *SessionProvider) SessionAll(ctx context.Context) int { + return 0 +} + +// SessionGC do nothing +func (s *SessionProvider) SessionGC(ctx context.Context) { + // we do anything since we don't need to mock GC +} + +type SessionStore struct { + sid string + values map[interface{}]interface{} +} + +func (s *SessionStore) Set(ctx context.Context, key, value interface{}) error { + s.values[key] = value + return nil +} + +func (s *SessionStore) Get(ctx context.Context, key interface{}) interface{} { + return s.values[key] +} + +func (s *SessionStore) Delete(ctx context.Context, key interface{}) error { + delete(s.values, key) + return nil +} + +func (s *SessionStore) SessionID(ctx context.Context) string { + return s.sid +} + +// SessionRelease do nothing +func (s *SessionStore) SessionRelease(_ context.Context, _ http.ResponseWriter) { + // Support in the future if necessary, now I think we don't need to implement this +} + +// SessionReleaseIfPresent do nothing +func (*SessionStore) SessionReleaseIfPresent(_ context.Context, _ http.ResponseWriter) { + // Support in the future if necessary, now I think we don't need to implement this +} + +func (s *SessionStore) Flush(ctx context.Context) error { + s.values = make(map[interface{}]interface{}, 4) + return nil +} + +func newSessionStore() *SessionStore { + return &SessionStore{ + sid: uuid.New().String(), + values: make(map[interface{}]interface{}, 4), + } +} diff --git a/server/web/mock/session_test.go b/server/web/mock/session_test.go new file mode 100644 index 0000000000..f9eef54c6d --- /dev/null +++ b/server/web/mock/session_test.go @@ -0,0 +1,48 @@ +// Copyright 2021 beego +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mock + +import ( + "bytes" + "net/http" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/beego/beego/v2/server/web" +) + +func TestSessionProvider(t *testing.T) { + sp := NewSessionProvider("file") + assert.NotNil(t, sp) + + req, err := http.NewRequest("GET", "http://localhost:8080/hello?name=tom", bytes.NewReader([]byte{})) + assert.Nil(t, err) + ctx, resp := NewMockContext(req) + ctrl := &TestController{ + Controller: web.Controller{ + Ctx: ctx, + }, + } + ctrl.HelloSession() + result := resp.Body.String() + assert.Equal(t, "set", result) + + resp.Body.Reset() + ctrl.HelloSessionName() + result = resp.Body.String() + + assert.Equal(t, "Tom", result) +} diff --git a/namespace.go b/server/web/namespace.go similarity index 57% rename from namespace.go rename to server/web/namespace.go index 4952c9d568..0a9c7e6e71 100644 --- a/namespace.go +++ b/server/web/namespace.go @@ -12,13 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. -package beego +package web import ( "net/http" "strings" - beecontext "github.com/astaxie/beego/context" + beecontext "github.com/beego/beego/v2/server/web/context" ) type namespaceCond func(*beecontext.Context) bool @@ -47,12 +47,14 @@ func NewNamespace(prefix string, params ...LinkNamespace) *Namespace { // Cond set condition function // if cond return true can run this namespace, else can't // usage: -// ns.Cond(func (ctx *context.Context) bool{ -// if ctx.Input.Domain() == "api.beego.me" { -// return true -// } -// return false -// }) +// +// ns.Cond(func (ctx *context.Context) bool{ +// if ctx.Input.Domain() == "api.beego.vip" { +// return true +// } +// return false +// }) +// // Cond as the first filter func (n *Namespace) Cond(cond namespaceCond) *Namespace { fn := func(ctx *beecontext.Context) { @@ -77,12 +79,13 @@ func (n *Namespace) Cond(cond namespaceCond) *Namespace { // action has before & after // FilterFunc // usage: -// Filter("before", func (ctx *context.Context){ -// _, ok := ctx.Input.Session("uid").(int) -// if !ok && ctx.Request.RequestURI != "/login" { -// ctx.Redirect(302, "/login") -// } -// }) +// +// Filter("before", func (ctx *context.Context){ +// _, ok := ctx.Input.Session("uid").(int) +// if !ok && ctx.Request.RequestURI != "/login" { +// ctx.Redirect(302, "/login") +// } +// }) func (n *Namespace) Filter(action string, filter ...FilterFunc) *Namespace { var a int if action == "before" { @@ -91,119 +94,169 @@ func (n *Namespace) Filter(action string, filter ...FilterFunc) *Namespace { a = FinishRouter } for _, f := range filter { - n.handlers.InsertFilter("*", a, f) + n.handlers.InsertFilter("*", a, f, WithReturnOnOutput(true)) } return n } // Router same as beego.Rourer -// refer: https://godoc.org/github.com/astaxie/beego#Router +// refer: https://godoc.org/github.com/beego/beego/v2#Router func (n *Namespace) Router(rootpath string, c ControllerInterface, mappingMethods ...string) *Namespace { - n.handlers.Add(rootpath, c, mappingMethods...) + n.handlers.Add(rootpath, c, WithRouterMethods(c, mappingMethods...)) return n } // AutoRouter same as beego.AutoRouter -// refer: https://godoc.org/github.com/astaxie/beego#AutoRouter +// refer: https://godoc.org/github.com/beego/beego/v2#AutoRouter func (n *Namespace) AutoRouter(c ControllerInterface) *Namespace { n.handlers.AddAuto(c) return n } // AutoPrefix same as beego.AutoPrefix -// refer: https://godoc.org/github.com/astaxie/beego#AutoPrefix +// refer: https://godoc.org/github.com/beego/beego/v2#AutoPrefix func (n *Namespace) AutoPrefix(prefix string, c ControllerInterface) *Namespace { n.handlers.AddAutoPrefix(prefix, c) return n } // Get same as beego.Get -// refer: https://godoc.org/github.com/astaxie/beego#Get -func (n *Namespace) Get(rootpath string, f FilterFunc) *Namespace { +// refer: https://godoc.org/github.com/beego/beego/v2#Get +func (n *Namespace) Get(rootpath string, f HandleFunc) *Namespace { n.handlers.Get(rootpath, f) return n } // Post same as beego.Post -// refer: https://godoc.org/github.com/astaxie/beego#Post -func (n *Namespace) Post(rootpath string, f FilterFunc) *Namespace { +// refer: https://godoc.org/github.com/beego/beego/v2#Post +func (n *Namespace) Post(rootpath string, f HandleFunc) *Namespace { n.handlers.Post(rootpath, f) return n } // Delete same as beego.Delete -// refer: https://godoc.org/github.com/astaxie/beego#Delete -func (n *Namespace) Delete(rootpath string, f FilterFunc) *Namespace { +// refer: https://godoc.org/github.com/beego/beego/v2#Delete +func (n *Namespace) Delete(rootpath string, f HandleFunc) *Namespace { n.handlers.Delete(rootpath, f) return n } // Put same as beego.Put -// refer: https://godoc.org/github.com/astaxie/beego#Put -func (n *Namespace) Put(rootpath string, f FilterFunc) *Namespace { +// refer: https://godoc.org/github.com/beego/beego/v2#Put +func (n *Namespace) Put(rootpath string, f HandleFunc) *Namespace { n.handlers.Put(rootpath, f) return n } // Head same as beego.Head -// refer: https://godoc.org/github.com/astaxie/beego#Head -func (n *Namespace) Head(rootpath string, f FilterFunc) *Namespace { +// refer: https://godoc.org/github.com/beego/beego/v2#Head +func (n *Namespace) Head(rootpath string, f HandleFunc) *Namespace { n.handlers.Head(rootpath, f) return n } // Options same as beego.Options -// refer: https://godoc.org/github.com/astaxie/beego#Options -func (n *Namespace) Options(rootpath string, f FilterFunc) *Namespace { +// refer: https://godoc.org/github.com/beego/beego/v2#Options +func (n *Namespace) Options(rootpath string, f HandleFunc) *Namespace { n.handlers.Options(rootpath, f) return n } // Patch same as beego.Patch -// refer: https://godoc.org/github.com/astaxie/beego#Patch -func (n *Namespace) Patch(rootpath string, f FilterFunc) *Namespace { +// refer: https://godoc.org/github.com/beego/beego/v2#Patch +func (n *Namespace) Patch(rootpath string, f HandleFunc) *Namespace { n.handlers.Patch(rootpath, f) return n } // Any same as beego.Any -// refer: https://godoc.org/github.com/astaxie/beego#Any -func (n *Namespace) Any(rootpath string, f FilterFunc) *Namespace { +// refer: https://godoc.org/github.com/beego/beego/v2#Any +func (n *Namespace) Any(rootpath string, f HandleFunc) *Namespace { n.handlers.Any(rootpath, f) return n } // Handler same as beego.Handler -// refer: https://godoc.org/github.com/astaxie/beego#Handler +// refer: https://godoc.org/github.com/beego/beego/v2#Handler func (n *Namespace) Handler(rootpath string, h http.Handler) *Namespace { n.handlers.Handler(rootpath, h) return n } // Include add include class -// refer: https://godoc.org/github.com/astaxie/beego#Include +// refer: https://godoc.org/github.com/beego/beego/v2#Include func (n *Namespace) Include(cList ...ControllerInterface) *Namespace { n.handlers.Include(cList...) return n } +// CtrlGet same as beego.CtrlGet +func (n *Namespace) CtrlGet(rootpath string, f interface{}) *Namespace { + n.handlers.CtrlGet(rootpath, f) + return n +} + +// CtrlPost same as beego.CtrlPost +func (n *Namespace) CtrlPost(rootpath string, f interface{}) *Namespace { + n.handlers.CtrlPost(rootpath, f) + return n +} + +// CtrlDelete same as beego.CtrlDelete +func (n *Namespace) CtrlDelete(rootpath string, f interface{}) *Namespace { + n.handlers.CtrlDelete(rootpath, f) + return n +} + +// CtrlPut same as beego.CtrlPut +func (n *Namespace) CtrlPut(rootpath string, f interface{}) *Namespace { + n.handlers.CtrlPut(rootpath, f) + return n +} + +// CtrlHead same as beego.CtrlHead +func (n *Namespace) CtrlHead(rootpath string, f interface{}) *Namespace { + n.handlers.CtrlHead(rootpath, f) + return n +} + +// CtrlOptions same as beego.CtrlOptions +func (n *Namespace) CtrlOptions(rootpath string, f interface{}) *Namespace { + n.handlers.CtrlOptions(rootpath, f) + return n +} + +// CtrlPatch same as beego.CtrlPatch +func (n *Namespace) CtrlPatch(rootpath string, f interface{}) *Namespace { + n.handlers.CtrlPatch(rootpath, f) + return n +} + +// Any same as beego.CtrlAny +func (n *Namespace) CtrlAny(rootpath string, f interface{}) *Namespace { + n.handlers.CtrlAny(rootpath, f) + return n +} + // Namespace add nest Namespace // usage: -//ns := beego.NewNamespace(“/v1”). -//Namespace( -// beego.NewNamespace("/shop"). -// Get("/:id", func(ctx *context.Context) { -// ctx.Output.Body([]byte("shopinfo")) -// }), -// beego.NewNamespace("/order"). -// Get("/:id", func(ctx *context.Context) { -// ctx.Output.Body([]byte("orderinfo")) -// }), -// beego.NewNamespace("/crm"). -// Get("/:id", func(ctx *context.Context) { -// ctx.Output.Body([]byte("crminfo")) -// }), -//) +// ns := beego.NewNamespace(“/v1”). +// Namespace( +// +// beego.NewNamespace("/shop"). +// Get("/:id", func(ctx *context.Context) { +// ctx.Output.Body([]byte("shopinfo")) +// }), +// beego.NewNamespace("/order"). +// Get("/:id", func(ctx *context.Context) { +// ctx.Output.Body([]byte("orderinfo")) +// }), +// beego.NewNamespace("/crm"). +// Get("/:id", func(ctx *context.Context) { +// ctx.Output.Body([]byte("crminfo")) +// }), +// +// ) func (n *Namespace) Namespace(ns ...*Namespace) *Namespace { for _, ni := range ns { for k, v := range ni.handlers.routers { @@ -311,61 +364,117 @@ func NSRouter(rootpath string, c ControllerInterface, mappingMethods ...string) } // NSGet call Namespace Get -func NSGet(rootpath string, f FilterFunc) LinkNamespace { +func NSGet(rootpath string, f HandleFunc) LinkNamespace { return func(ns *Namespace) { ns.Get(rootpath, f) } } // NSPost call Namespace Post -func NSPost(rootpath string, f FilterFunc) LinkNamespace { +func NSPost(rootpath string, f HandleFunc) LinkNamespace { return func(ns *Namespace) { ns.Post(rootpath, f) } } // NSHead call Namespace Head -func NSHead(rootpath string, f FilterFunc) LinkNamespace { +func NSHead(rootpath string, f HandleFunc) LinkNamespace { return func(ns *Namespace) { ns.Head(rootpath, f) } } // NSPut call Namespace Put -func NSPut(rootpath string, f FilterFunc) LinkNamespace { +func NSPut(rootpath string, f HandleFunc) LinkNamespace { return func(ns *Namespace) { ns.Put(rootpath, f) } } // NSDelete call Namespace Delete -func NSDelete(rootpath string, f FilterFunc) LinkNamespace { +func NSDelete(rootpath string, f HandleFunc) LinkNamespace { return func(ns *Namespace) { ns.Delete(rootpath, f) } } // NSAny call Namespace Any -func NSAny(rootpath string, f FilterFunc) LinkNamespace { +func NSAny(rootpath string, f HandleFunc) LinkNamespace { return func(ns *Namespace) { ns.Any(rootpath, f) } } // NSOptions call Namespace Options -func NSOptions(rootpath string, f FilterFunc) LinkNamespace { +func NSOptions(rootpath string, f HandleFunc) LinkNamespace { return func(ns *Namespace) { ns.Options(rootpath, f) } } // NSPatch call Namespace Patch -func NSPatch(rootpath string, f FilterFunc) LinkNamespace { +func NSPatch(rootpath string, f HandleFunc) LinkNamespace { return func(ns *Namespace) { ns.Patch(rootpath, f) } } +// NSCtrlGet call Namespace CtrlGet +func NSCtrlGet(rootpath string, f interface{}) LinkNamespace { + return func(ns *Namespace) { + ns.CtrlGet(rootpath, f) + } +} + +// NSCtrlPost call Namespace CtrlPost +func NSCtrlPost(rootpath string, f interface{}) LinkNamespace { + return func(ns *Namespace) { + ns.CtrlPost(rootpath, f) + } +} + +// NSCtrlHead call Namespace CtrlHead +func NSCtrlHead(rootpath string, f interface{}) LinkNamespace { + return func(ns *Namespace) { + ns.CtrlHead(rootpath, f) + } +} + +// NSCtrlPut call Namespace CtrlPut +func NSCtrlPut(rootpath string, f interface{}) LinkNamespace { + return func(ns *Namespace) { + ns.CtrlPut(rootpath, f) + } +} + +// NSCtrlDelete call Namespace CtrlDelete +func NSCtrlDelete(rootpath string, f interface{}) LinkNamespace { + return func(ns *Namespace) { + ns.CtrlDelete(rootpath, f) + } +} + +// NSCtrlAny call Namespace CtrlAny +func NSCtrlAny(rootpath string, f interface{}) LinkNamespace { + return func(ns *Namespace) { + ns.CtrlAny(rootpath, f) + } +} + +// NSCtrlOptions call Namespace CtrlOptions +func NSCtrlOptions(rootpath string, f interface{}) LinkNamespace { + return func(ns *Namespace) { + ns.CtrlOptions(rootpath, f) + } +} + +// NSCtrlPatch call Namespace CtrlPatch +func NSCtrlPatch(rootpath string, f interface{}) LinkNamespace { + return func(ns *Namespace) { + ns.CtrlPatch(rootpath, f) + } +} + // NSAutoRouter call Namespace AutoRouter func NSAutoRouter(c ControllerInterface) LinkNamespace { return func(ns *Namespace) { diff --git a/server/web/namespace_test.go b/server/web/namespace_test.go new file mode 100644 index 0000000000..9451a525d7 --- /dev/null +++ b/server/web/namespace_test.go @@ -0,0 +1,415 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package web + +import ( + "fmt" + "net/http" + "net/http/httptest" + "strconv" + "testing" + + "github.com/beego/beego/v2/server/web/context" +) + +const ( + exampleBody = "hello world" + examplePointerBody = "hello world pointer" + + nsNamespace = "/router" + nsPath = "/user" + nsNamespacePath = "/router/user" +) + +type ExampleController struct { + Controller +} + +func (m ExampleController) Ping() { + err := m.Ctx.Output.Body([]byte(exampleBody)) + if err != nil { + fmt.Println(err) + } +} + +func (m *ExampleController) PingPointer() { + err := m.Ctx.Output.Body([]byte(examplePointerBody)) + if err != nil { + fmt.Println(err) + } +} + +func (m ExampleController) ping() { + err := m.Ctx.Output.Body([]byte("ping method")) + if err != nil { + fmt.Println(err) + } +} + +func TestNamespaceGet(t *testing.T) { + r, _ := http.NewRequest("GET", "/v1/user", nil) + w := httptest.NewRecorder() + + ns := NewNamespace("/v1") + ns.Get("/user", func(ctx *context.Context) { + ctx.Output.Body([]byte("v1_user")) + }) + AddNamespace(ns) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != "v1_user" { + t.Errorf("TestNamespaceGet can't run, get the response is " + w.Body.String()) + } +} + +func TestNamespacePost(t *testing.T) { + r, _ := http.NewRequest("POST", "/v1/user/123", nil) + w := httptest.NewRecorder() + + ns := NewNamespace("/v1") + ns.Post("/user/:id", func(ctx *context.Context) { + ctx.Output.Body([]byte(ctx.Input.Param(":id"))) + }) + AddNamespace(ns) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != "123" { + t.Errorf("TestNamespacePost can't run, get the response is " + w.Body.String()) + } +} + +func TestNamespaceNest(t *testing.T) { + r, _ := http.NewRequest("GET", "/v1/admin/order", nil) + w := httptest.NewRecorder() + + ns := NewNamespace("/v1") + ns.Namespace( + NewNamespace("/admin"). + Get("/order", func(ctx *context.Context) { + ctx.Output.Body([]byte("order")) + }), + ) + AddNamespace(ns) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != "order" { + t.Errorf("TestNamespaceNest can't run, get the response is " + w.Body.String()) + } +} + +func TestNamespaceNestParam(t *testing.T) { + r, _ := http.NewRequest("GET", "/v1/admin/order/123", nil) + w := httptest.NewRecorder() + + ns := NewNamespace("/v1") + ns.Namespace( + NewNamespace("/admin"). + Get("/order/:id", func(ctx *context.Context) { + ctx.Output.Body([]byte(ctx.Input.Param(":id"))) + }), + ) + AddNamespace(ns) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != "123" { + t.Errorf("TestNamespaceNestParam can't run, get the response is " + w.Body.String()) + } +} + +func TestNamespaceRouter(t *testing.T) { + r, _ := http.NewRequest("GET", "/v1/api/list", nil) + w := httptest.NewRecorder() + + ns := NewNamespace("/v1") + ns.Router("/api/list", &TestController{}, "*:List") + AddNamespace(ns) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != "i am list" { + t.Errorf("TestNamespaceRouter can't run, get the response is " + w.Body.String()) + } +} + +func TestNamespaceAutoFunc(t *testing.T) { + r, _ := http.NewRequest("GET", "/v1/test/list", nil) + w := httptest.NewRecorder() + + ns := NewNamespace("/v1") + ns.AutoRouter(&TestController{}) + AddNamespace(ns) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != "i am list" { + t.Errorf("user define func can't run") + } +} + +func TestNamespaceFilter(t *testing.T) { + r, _ := http.NewRequest("GET", "/v1/user/123", nil) + w := httptest.NewRecorder() + + ns := NewNamespace("/v1") + ns.Filter("before", func(ctx *context.Context) { + ctx.Output.Body([]byte("this is Filter")) + }). + Get("/user/:id", func(ctx *context.Context) { + ctx.Output.Body([]byte(ctx.Input.Param(":id"))) + }) + AddNamespace(ns) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != "this is Filter" { + t.Errorf("TestNamespaceFilter can't run, get the response is " + w.Body.String()) + } +} + +func TestNamespaceCond(t *testing.T) { + r, _ := http.NewRequest("GET", "/v2/test/list", nil) + w := httptest.NewRecorder() + + ns := NewNamespace("/v2") + ns.Cond(func(ctx *context.Context) bool { + return ctx.Input.Domain() == "beego.vip" + }). + AutoRouter(&TestController{}) + AddNamespace(ns) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Code != 405 { + t.Errorf("TestNamespaceCond can't run get the result " + strconv.Itoa(w.Code)) + } +} + +func TestNamespaceInside(t *testing.T) { + r, _ := http.NewRequest("GET", "/v3/shop/order/123", nil) + w := httptest.NewRecorder() + ns := NewNamespace("/v3", + NSAutoRouter(&TestController{}), + NSNamespace("/shop", + NSGet("/order/:id", func(ctx *context.Context) { + ctx.Output.Body([]byte(ctx.Input.Param(":id"))) + }), + ), + ) + AddNamespace(ns) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != "123" { + t.Errorf("TestNamespaceInside can't run, get the response is " + w.Body.String()) + } +} + +func TestNamespaceCtrlGet(t *testing.T) { + r, _ := http.NewRequest(http.MethodGet, nsNamespacePath, nil) + w := httptest.NewRecorder() + + ns := NewNamespace(nsNamespace) + ns.CtrlGet(nsPath, ExampleController.Ping) + AddNamespace(ns) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != exampleBody { + t.Errorf("TestNamespaceCtrlGet can't run, get the response is " + w.Body.String()) + } +} + +func TestNamespaceCtrlPost(t *testing.T) { + r, _ := http.NewRequest(http.MethodPost, nsNamespacePath, nil) + w := httptest.NewRecorder() + + ns := NewNamespace(nsNamespace) + ns.CtrlPost(nsPath, ExampleController.Ping) + AddNamespace(ns) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != exampleBody { + t.Errorf("TestNamespaceCtrlPost can't run, get the response is " + w.Body.String()) + } +} + +func TestNamespaceCtrlDelete(t *testing.T) { + r, _ := http.NewRequest(http.MethodDelete, nsNamespacePath, nil) + w := httptest.NewRecorder() + + ns := NewNamespace(nsNamespace) + ns.CtrlDelete(nsPath, ExampleController.Ping) + AddNamespace(ns) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != exampleBody { + t.Errorf("TestNamespaceCtrlDelete can't run, get the response is " + w.Body.String()) + } +} + +func TestNamespaceCtrlPut(t *testing.T) { + r, _ := http.NewRequest(http.MethodPut, nsNamespacePath, nil) + w := httptest.NewRecorder() + + ns := NewNamespace(nsNamespace) + ns.CtrlPut(nsPath, ExampleController.Ping) + AddNamespace(ns) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != exampleBody { + t.Errorf("TestNamespaceCtrlPut can't run, get the response is " + w.Body.String()) + } +} + +func TestNamespaceCtrlHead(t *testing.T) { + r, _ := http.NewRequest(http.MethodHead, nsNamespacePath, nil) + w := httptest.NewRecorder() + + ns := NewNamespace(nsNamespace) + ns.CtrlHead(nsPath, ExampleController.Ping) + AddNamespace(ns) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != exampleBody { + t.Errorf("TestNamespaceCtrlHead can't run, get the response is " + w.Body.String()) + } +} + +func TestNamespaceCtrlOptions(t *testing.T) { + r, _ := http.NewRequest(http.MethodOptions, nsNamespacePath, nil) + w := httptest.NewRecorder() + + ns := NewNamespace(nsNamespace) + ns.CtrlOptions(nsPath, ExampleController.Ping) + AddNamespace(ns) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != exampleBody { + t.Errorf("TestNamespaceCtrlOptions can't run, get the response is " + w.Body.String()) + } +} + +func TestNamespaceCtrlPatch(t *testing.T) { + r, _ := http.NewRequest(http.MethodPatch, nsNamespacePath, nil) + w := httptest.NewRecorder() + + ns := NewNamespace(nsNamespace) + ns.CtrlPatch(nsPath, ExampleController.Ping) + AddNamespace(ns) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != exampleBody { + t.Errorf("TestNamespaceCtrlPatch can't run, get the response is " + w.Body.String()) + } +} + +func TestNamespaceCtrlAny(t *testing.T) { + ns := NewNamespace(nsNamespace) + ns.CtrlAny(nsPath, ExampleController.Ping) + AddNamespace(ns) + + for method := range HTTPMETHOD { + w := httptest.NewRecorder() + r, _ := http.NewRequest(method, nsNamespacePath, nil) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != exampleBody { + t.Errorf("TestNamespaceCtrlAny can't run, get the response is " + w.Body.String()) + } + } +} + +func TestNamespaceNSCtrlGet(t *testing.T) { + r, _ := http.NewRequest(http.MethodGet, nsNamespacePath, nil) + w := httptest.NewRecorder() + + ns := NewNamespace(nsNamespace) + NSCtrlGet(nsPath, ExampleController.Ping)(ns) + AddNamespace(ns) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != exampleBody { + t.Errorf("TestNamespaceNSCtrlGet can't run, get the response is " + w.Body.String()) + } +} + +func TestNamespaceNSCtrlPost(t *testing.T) { + r, _ := http.NewRequest(http.MethodPost, nsNamespacePath, nil) + w := httptest.NewRecorder() + + ns := NewNamespace("/router") + NSCtrlPost(nsPath, ExampleController.Ping)(ns) + AddNamespace(ns) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != exampleBody { + t.Errorf("TestNamespaceNSCtrlPost can't run, get the response is " + w.Body.String()) + } +} + +func TestNamespaceNSCtrlDelete(t *testing.T) { + r, _ := http.NewRequest(http.MethodDelete, nsNamespacePath, nil) + w := httptest.NewRecorder() + + ns := NewNamespace(nsNamespace) + NSCtrlDelete(nsPath, ExampleController.Ping)(ns) + AddNamespace(ns) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != exampleBody { + t.Errorf("TestNamespaceNSCtrlDelete can't run, get the response is " + w.Body.String()) + } +} + +func TestNamespaceNSCtrlPut(t *testing.T) { + r, _ := http.NewRequest(http.MethodPut, nsNamespacePath, nil) + w := httptest.NewRecorder() + + ns := NewNamespace(nsNamespace) + NSCtrlPut(nsPath, ExampleController.Ping)(ns) + AddNamespace(ns) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != exampleBody { + t.Errorf("TestNamespaceNSCtrlPut can't run, get the response is " + w.Body.String()) + } +} + +func TestNamespaceNSCtrlHead(t *testing.T) { + r, _ := http.NewRequest(http.MethodHead, nsNamespacePath, nil) + w := httptest.NewRecorder() + + ns := NewNamespace(nsNamespace) + NSCtrlHead(nsPath, ExampleController.Ping)(ns) + AddNamespace(ns) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != exampleBody { + t.Errorf("TestNamespaceNSCtrlHead can't run, get the response is " + w.Body.String()) + } +} + +func TestNamespaceNSCtrlOptions(t *testing.T) { + r, _ := http.NewRequest(http.MethodOptions, nsNamespacePath, nil) + w := httptest.NewRecorder() + + ns := NewNamespace(nsNamespace) + NSCtrlOptions(nsPath, ExampleController.Ping)(ns) + AddNamespace(ns) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != exampleBody { + t.Errorf("TestNamespaceNSCtrlOptions can't run, get the response is " + w.Body.String()) + } +} + +func TestNamespaceNSCtrlPatch(t *testing.T) { + r, _ := http.NewRequest(http.MethodPatch, nsNamespacePath, nil) + w := httptest.NewRecorder() + + ns := NewNamespace(nsNamespace) + NSCtrlPatch("/user", ExampleController.Ping)(ns) + AddNamespace(ns) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != exampleBody { + t.Errorf("TestNamespaceNSCtrlPatch can't run, get the response is " + w.Body.String()) + } +} + +func TestNamespaceNSCtrlAny(t *testing.T) { + ns := NewNamespace(nsNamespace) + NSCtrlAny(nsPath, ExampleController.Ping)(ns) + AddNamespace(ns) + + for method := range HTTPMETHOD { + w := httptest.NewRecorder() + r, _ := http.NewRequest(method, nsNamespacePath, nil) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != exampleBody { + t.Errorf("TestNamespaceNSCtrlAny can't run, get the response is " + w.Body.String()) + } + } +} diff --git a/utils/pagination/controller.go b/server/web/pagination/controller.go similarity index 81% rename from utils/pagination/controller.go rename to server/web/pagination/controller.go index 2f022d0c76..d1299768f4 100644 --- a/utils/pagination/controller.go +++ b/server/web/pagination/controller.go @@ -15,12 +15,13 @@ package pagination import ( - "github.com/astaxie/beego/context" + "github.com/beego/beego/v2/core/utils/pagination" + "github.com/beego/beego/v2/server/web/context" ) // SetPaginator Instantiates a Paginator and assigns it to context.Input.Data("paginator"). -func SetPaginator(context *context.Context, per int, nums int64) (paginator *Paginator) { - paginator = NewPaginator(context.Request, per, nums) +func SetPaginator(context *context.Context, per int, nums int64) (paginator *pagination.Paginator) { + paginator = pagination.NewPaginator(context.Request, per, nums) context.Input.SetData("paginator", &paginator) return } diff --git a/policy.go b/server/web/policy.go similarity index 96% rename from policy.go rename to server/web/policy.go index ab23f927af..41fb2669d7 100644 --- a/policy.go +++ b/server/web/policy.go @@ -12,12 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. -package beego +package web import ( "strings" - "github.com/astaxie/beego/context" + "github.com/beego/beego/v2/server/web/context" ) // PolicyFunc defines a policy function which is invoked before the controller handler is executed. @@ -25,7 +25,7 @@ type PolicyFunc func(*context.Context) // FindPolicy Find Router info for URL func (p *ControllerRegister) FindPolicy(cont *context.Context) []PolicyFunc { - var urlPath = cont.Input.URL() + urlPath := cont.Input.URL() if !BConfig.RouterCaseSensitive { urlPath = strings.ToLower(urlPath) } diff --git a/server/web/router.go b/server/web/router.go new file mode 100644 index 0000000000..bdaff018a4 --- /dev/null +++ b/server/web/router.go @@ -0,0 +1,1388 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package web + +import ( + "bytes" + "context" + "errors" + "fmt" + "io" + "net/http" + "path" + "reflect" + "runtime" + "strconv" + "strings" + "sync" + "time" + + "github.com/beego/beego/v2/core/logs" + "github.com/beego/beego/v2/core/utils" + beecontext "github.com/beego/beego/v2/server/web/context" + "github.com/beego/beego/v2/server/web/context/param" +) + +// default filter execution points +const ( + BeforeStatic = iota + BeforeRouter + BeforeExec + AfterExec + FinishRouter +) + +const ( + routerTypeBeego = iota + routerTypeRESTFul + routerTypeHandler +) + +var ( + // HTTPMETHOD list the supported http methods. + HTTPMETHOD = map[string]bool{ + "GET": true, + "POST": true, + "PUT": true, + "DELETE": true, + "PATCH": true, + "OPTIONS": true, + "HEAD": true, + "TRACE": true, + "CONNECT": true, + "MKCOL": true, + "COPY": true, + "MOVE": true, + "PROPFIND": true, + "PROPPATCH": true, + "LOCK": true, + "UNLOCK": true, + } + // these web.Controller's methods shouldn't reflect to AutoRouter + // see registerControllerExceptMethods + exceptMethod = initExceptMethod() + + urlPlaceholder = "{{placeholder}}" + // DefaultAccessLogFilter will skip the accesslog if return true + DefaultAccessLogFilter FilterHandler = &logFilter{} +) + +// FilterHandler is an interface for +type FilterHandler interface { + Filter(*beecontext.Context) bool +} + +// default log filter static file will not show +type logFilter struct{} + +func (l *logFilter) Filter(ctx *beecontext.Context) bool { + requestPath := path.Clean(ctx.Request.URL.Path) + if requestPath == "/favicon.ico" || requestPath == "/robots.txt" { + return true + } + for prefix := range BConfig.WebConfig.StaticDir { + if strings.HasPrefix(requestPath, prefix) { + return true + } + } + return false +} + +// ExceptMethodAppend to append a slice's value into "exceptMethod", for controller's methods shouldn't reflect to AutoRouter +func ExceptMethodAppend(action string) { + exceptMethod = append(exceptMethod, action) +} + +func initExceptMethod() []string { + res := make([]string, 0, 32) + c := &Controller{} + t := reflect.TypeOf(c) + for i := 0; i < t.NumMethod(); i++ { + m := t.Method(i) + res = append(res, m.Name) + } + return res +} + +// ControllerInfo holds information about the controller. +type ControllerInfo struct { + pattern string + controllerType reflect.Type + methods map[string]string + handler http.Handler + runFunction HandleFunc + routerType int + initialize func() ControllerInterface + methodParams []*param.MethodParam + sessionOn bool +} + +type ControllerOption func(*ControllerInfo) + +func (c *ControllerInfo) GetPattern() string { + return c.pattern +} + +func (c *ControllerInfo) GetMethod() map[string]string { + return c.methods +} + +func WithRouterMethods(ctrlInterface ControllerInterface, mappingMethod ...string) ControllerOption { + return func(c *ControllerInfo) { + c.methods = parseMappingMethods(ctrlInterface, mappingMethod) + } +} + +func WithRouterSessionOn(sessionOn bool) ControllerOption { + return func(c *ControllerInfo) { + c.sessionOn = sessionOn + } +} + +type filterChainConfig struct { + pattern string + chain FilterChain + opts []FilterOpt +} + +// ControllerRegister containers registered router rules, controller handlers and filters. +type ControllerRegister struct { + routers map[string]*Tree + enablePolicy bool + enableFilter bool + policies map[string]*Tree + filters [FinishRouter + 1][]*FilterRouter + pool sync.Pool + + // the filter created by FilterChain + chainRoot *FilterRouter + + // keep registered chain and build it when serve http + filterChains []filterChainConfig + + cfg *Config +} + +// NewControllerRegister returns a new ControllerRegister. +// Usually you should not use this method +// please use NewControllerRegisterWithCfg +func NewControllerRegister() *ControllerRegister { + return NewControllerRegisterWithCfg(BeeApp.Cfg) +} + +func NewControllerRegisterWithCfg(cfg *Config) *ControllerRegister { + res := &ControllerRegister{ + routers: make(map[string]*Tree), + policies: make(map[string]*Tree), + pool: sync.Pool{ + New: func() interface{} { + return beecontext.NewContext() + }, + }, + cfg: cfg, + filterChains: make([]filterChainConfig, 0, 4), + } + res.chainRoot = newFilterRouter("/*", res.serveHttp, WithCaseSensitive(false)) + return res +} + +// Init will be executed when HttpServer start running +func (p *ControllerRegister) Init() { + for i := len(p.filterChains) - 1; i >= 0; i-- { + fc := p.filterChains[i] + root := p.chainRoot + filterFunc := fc.chain(func(ctx *beecontext.Context) { + var preFilterParams map[string]string + root.filter(ctx, p.getUrlPath(ctx), preFilterParams) + }) + p.chainRoot = newFilterRouter(fc.pattern, filterFunc, fc.opts...) + p.chainRoot.next = root + } +} + +// Add controller handler and pattern rules to ControllerRegister. +// usage: +// +// default methods is the same name as method +// Add("/user",&UserController{}) +// Add("/api/list",&RestController{},"*:ListFood") +// Add("/api/create",&RestController{},"post:CreateFood") +// Add("/api/update",&RestController{},"put:UpdateFood") +// Add("/api/delete",&RestController{},"delete:DeleteFood") +// Add("/api",&RestController{},"get,post:ApiFunc" +// Add("/simple",&SimpleController{},"get:GetFunc;post:PostFunc") +func (p *ControllerRegister) Add(pattern string, c ControllerInterface, opts ...ControllerOption) { + p.addWithMethodParams(pattern, c, nil, opts...) +} + +func parseMappingMethods(c ControllerInterface, mappingMethods []string) map[string]string { + reflectVal := reflect.ValueOf(c) + t := reflect.Indirect(reflectVal).Type() + methods := make(map[string]string) + + if len(mappingMethods) == 0 { + return methods + } + + semi := strings.Split(mappingMethods[0], ";") + for _, v := range semi { + colon := strings.Split(v, ":") + if len(colon) != 2 { + panic("method mapping format is invalid") + } + comma := strings.Split(colon[0], ",") + for _, m := range comma { + if m != "*" && !HTTPMETHOD[strings.ToUpper(m)] { + panic(v + " is an invalid method mapping. Method doesn't exist " + m) + } + if val := reflectVal.MethodByName(colon[1]); val.IsValid() { + methods[strings.ToUpper(m)] = colon[1] + continue + } + panic("'" + colon[1] + "' method doesn't exist in the controller " + t.Name()) + } + } + + return methods +} + +func (p *ControllerRegister) addRouterForMethod(route *ControllerInfo) { + if len(route.methods) == 0 { + for m := range HTTPMETHOD { + p.addToRouter(m, route.pattern, route) + } + return + } + for k := range route.methods { + if k != "*" { + p.addToRouter(k, route.pattern, route) + continue + } + for m := range HTTPMETHOD { + p.addToRouter(m, route.pattern, route) + } + } +} + +func (p *ControllerRegister) addWithMethodParams(pattern string, c ControllerInterface, methodParams []*param.MethodParam, opts ...ControllerOption) { + reflectVal := reflect.ValueOf(c) + t := reflect.Indirect(reflectVal).Type() + + route := p.createBeegoRouter(t, pattern) + route.initialize = func() ControllerInterface { + vc := reflect.New(route.controllerType) + execController, ok := vc.Interface().(ControllerInterface) + if !ok { + panic("controller is not ControllerInterface") + } + + elemVal := reflect.ValueOf(c).Elem() + elemType := reflect.TypeOf(c).Elem() + execElem := reflect.ValueOf(execController).Elem() + + numOfFields := elemVal.NumField() + for i := 0; i < numOfFields; i++ { + fieldType := elemType.Field(i) + elemField := execElem.FieldByName(fieldType.Name) + if elemField.CanSet() { + fieldVal := elemVal.Field(i) + elemField.Set(fieldVal) + } + } + + return execController + } + route.methodParams = methodParams + for i := range opts { + opts[i](route) + } + + globalSessionOn := p.cfg.WebConfig.Session.SessionOn + if !globalSessionOn && route.sessionOn { + logs.Warn("global sessionOn is false, sessionOn of router [%s] can't be set to true", route.pattern) + route.sessionOn = globalSessionOn + } + + p.addRouterForMethod(route) +} + +func (p *ControllerRegister) addToRouter(method, pattern string, r *ControllerInfo) { + if !p.cfg.RouterCaseSensitive { + pattern = strings.ToLower(pattern) + } + if t, ok := p.routers[method]; ok { + t.AddRouter(pattern, r) + } else { + t := NewTree() + t.AddRouter(pattern, r) + p.routers[method] = t + } +} + +// Include only when the Runmode is dev will generate router file in the router/auto.go from the controller +// Include(&BankAccount{}, &OrderController{},&RefundController{},&ReceiptController{}) +func (p *ControllerRegister) Include(cList ...ControllerInterface) { + for _, c := range cList { + reflectVal := reflect.ValueOf(c) + t := reflect.Indirect(reflectVal).Type() + key := t.PkgPath() + ":" + t.Name() + if comm, ok := GlobalControllerRouter[key]; ok { + for _, a := range comm { + for _, f := range a.Filters { + p.InsertFilter(f.Pattern, f.Pos, f.Filter, WithReturnOnOutput(f.ReturnOnOutput), WithResetParams(f.ResetParams)) + } + p.addWithMethodParams(a.Router, c, a.MethodParams, WithRouterMethods(c, strings.Join(a.AllowHTTPMethods, ",")+":"+a.Method)) + } + } + } +} + +// GetContext returns a context from pool, so usually you should remember to call Reset function to clean the context +// And don't forget to give back context to pool +// example: +// +// ctx := p.GetContext() +// ctx.Reset(w, q) +// defer p.GiveBackContext(ctx) +func (p *ControllerRegister) GetContext() *beecontext.Context { + return p.pool.Get().(*beecontext.Context) +} + +// GiveBackContext put the ctx into pool so that it could be reuse +func (p *ControllerRegister) GiveBackContext(ctx *beecontext.Context) { + p.pool.Put(ctx) +} + +// CtrlGet add get method +// usage: +// +// type MyController struct { +// web.Controller +// } +// func (m MyController) Ping() { +// m.Ctx.Output.Body([]byte("hello world")) +// } +// +// CtrlGet("/api/:id", MyController.Ping) +// +// If the receiver of function Ping is pointer, you should use CtrlGet("/api/:id", (*MyController).Ping) +func (p *ControllerRegister) CtrlGet(pattern string, f interface{}) { + p.AddRouterMethod(http.MethodGet, pattern, f) +} + +// CtrlPost add post method +// usage: +// +// type MyController struct { +// web.Controller +// } +// func (m MyController) Ping() { +// m.Ctx.Output.Body([]byte("hello world")) +// } +// +// CtrlPost("/api/:id", MyController.Ping) +// +// If the receiver of function Ping is pointer, you should use CtrlPost("/api/:id", (*MyController).Ping) +func (p *ControllerRegister) CtrlPost(pattern string, f interface{}) { + p.AddRouterMethod(http.MethodPost, pattern, f) +} + +// CtrlHead add head method +// usage: +// +// type MyController struct { +// web.Controller +// } +// func (m MyController) Ping() { +// m.Ctx.Output.Body([]byte("hello world")) +// } +// +// CtrlHead("/api/:id", MyController.Ping) +// +// If the receiver of function Ping is pointer, you should use CtrlHead("/api/:id", (*MyController).Ping) +func (p *ControllerRegister) CtrlHead(pattern string, f interface{}) { + p.AddRouterMethod(http.MethodHead, pattern, f) +} + +// CtrlPut add put method +// usage: +// type MyController struct { +// web.Controller +// } +// func (m MyController) Ping() { +// m.Ctx.Output.Body([]byte("hello world")) +// } +// +// CtrlPut("/api/:id", MyController.Ping) + +func (p *ControllerRegister) CtrlPut(pattern string, f interface{}) { + p.AddRouterMethod(http.MethodPut, pattern, f) +} + +// CtrlPatch add patch method +// usage: +// +// type MyController struct { +// web.Controller +// } +// func (m MyController) Ping() { +// m.Ctx.Output.Body([]byte("hello world")) +// } +// +// CtrlPatch("/api/:id", MyController.Ping) +func (p *ControllerRegister) CtrlPatch(pattern string, f interface{}) { + p.AddRouterMethod(http.MethodPatch, pattern, f) +} + +// CtrlDelete add delete method +// usage: +// +// type MyController struct { +// web.Controller +// } +// func (m MyController) Ping() { +// m.Ctx.Output.Body([]byte("hello world")) +// } +// +// CtrlDelete("/api/:id", MyController.Ping) +func (p *ControllerRegister) CtrlDelete(pattern string, f interface{}) { + p.AddRouterMethod(http.MethodDelete, pattern, f) +} + +// CtrlOptions add options method +// usage: +// +// type MyController struct { +// web.Controller +// } +// func (m MyController) Ping() { +// m.Ctx.Output.Body([]byte("hello world")) +// } +// +// CtrlOptions("/api/:id", MyController.Ping) +func (p *ControllerRegister) CtrlOptions(pattern string, f interface{}) { + p.AddRouterMethod(http.MethodOptions, pattern, f) +} + +// CtrlAny add all method +// usage: +// +// type MyController struct { +// web.Controller +// } +// func (m MyController) Ping() { +// m.Ctx.Output.Body([]byte("hello world")) +// } +// +// CtrlAny("/api/:id", MyController.Ping) +func (p *ControllerRegister) CtrlAny(pattern string, f interface{}) { + p.AddRouterMethod("*", pattern, f) +} + +// AddRouterMethod add http method router +// usage: +// +// type MyController struct { +// web.Controller +// } +// func (m MyController) Ping() { +// m.Ctx.Output.Body([]byte("hello world")) +// } +// +// AddRouterMethod("get","/api/:id", MyController.Ping) +func (p *ControllerRegister) AddRouterMethod(httpMethod, pattern string, f interface{}) { + httpMethod = p.getUpperMethodString(httpMethod) + ct, methodName := getReflectTypeAndMethod(f) + + p.addBeegoTypeRouter(ct, methodName, httpMethod, pattern) +} + +// addBeegoTypeRouter add beego type router +func (p *ControllerRegister) addBeegoTypeRouter(ct reflect.Type, ctMethod, httpMethod, pattern string) { + route := p.createBeegoRouter(ct, pattern) + methods := p.getHttpMethodMapMethod(httpMethod, ctMethod) + route.methods = methods + + p.addRouterForMethod(route) +} + +// createBeegoRouter create beego router base on reflect type and pattern +func (p *ControllerRegister) createBeegoRouter(ct reflect.Type, pattern string) *ControllerInfo { + route := &ControllerInfo{} + route.pattern = pattern + route.routerType = routerTypeBeego + route.sessionOn = p.cfg.WebConfig.Session.SessionOn + route.controllerType = ct + return route +} + +// createRestfulRouter create restful router with filter function and pattern +func (p *ControllerRegister) createRestfulRouter(f HandleFunc, pattern string) *ControllerInfo { + route := &ControllerInfo{} + route.pattern = pattern + route.routerType = routerTypeRESTFul + route.sessionOn = p.cfg.WebConfig.Session.SessionOn + route.runFunction = f + return route +} + +// createHandlerRouter create handler router with handler and pattern +func (p *ControllerRegister) createHandlerRouter(h http.Handler, pattern string) *ControllerInfo { + route := &ControllerInfo{} + route.pattern = pattern + route.routerType = routerTypeHandler + route.sessionOn = p.cfg.WebConfig.Session.SessionOn + route.handler = h + return route +} + +// getHttpMethodMapMethod based on http method and controller method, if ctMethod is empty, then it will +// use http method as the controller method +func (p *ControllerRegister) getHttpMethodMapMethod(httpMethod, ctMethod string) map[string]string { + methods := make(map[string]string) + // not match-all sign, only add for the http method + if httpMethod != "*" { + + if ctMethod == "" { + ctMethod = httpMethod + } + methods[httpMethod] = ctMethod + return methods + } + + // add all http method + for val := range HTTPMETHOD { + if ctMethod == "" { + methods[val] = val + } else { + methods[val] = ctMethod + } + } + return methods +} + +// getUpperMethodString get upper string of method, and panic if the method +// is not valid +func (p *ControllerRegister) getUpperMethodString(method string) string { + method = strings.ToUpper(method) + if method != "*" && !HTTPMETHOD[method] { + panic("not support http method: " + method) + } + return method +} + +// get reflect controller type and method by controller method expression +func getReflectTypeAndMethod(f interface{}) (controllerType reflect.Type, method string) { + // check f is a function + funcType := reflect.TypeOf(f) + if funcType.Kind() != reflect.Func { + panic("not a method") + } + + // get function name + funcObj := runtime.FuncForPC(reflect.ValueOf(f).Pointer()) + if funcObj == nil { + panic("cannot find the method") + } + funcNameSli := strings.Split(funcObj.Name(), ".") + lFuncSli := len(funcNameSli) + if lFuncSli == 0 { + panic("invalid method full name: " + funcObj.Name()) + } + + method = funcNameSli[lFuncSli-1] + if len(method) == 0 { + panic("method name is empty") + } else if method[0] > 96 || method[0] < 65 { + panic(fmt.Sprintf("%s is not a public method", method)) + } + + // check only one param which is the method receiver + if numIn := funcType.NumIn(); numIn != 1 { + panic("invalid number of param in") + } + + controllerType = funcType.In(0) + + // check controller has the method + _, exists := controllerType.MethodByName(method) + if !exists { + panic(controllerType.String() + " has no method " + method) + } + + // check the receiver implement ControllerInterface + if controllerType.Kind() == reflect.Ptr { + controllerType = controllerType.Elem() + } + controller := reflect.New(controllerType) + _, ok := controller.Interface().(ControllerInterface) + if !ok { + panic(controllerType.String() + " is not implemented ControllerInterface") + } + + return +} + +// HandleFunc define how to process the request +type HandleFunc func(ctx *beecontext.Context) + +// Get add get method +// usage: +// +// Get("/", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func (p *ControllerRegister) Get(pattern string, f HandleFunc) { + p.AddMethod("get", pattern, f) +} + +// Post add post method +// usage: +// +// Post("/api", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func (p *ControllerRegister) Post(pattern string, f HandleFunc) { + p.AddMethod("post", pattern, f) +} + +// Put add put method +// usage: +// +// Put("/api/:id", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func (p *ControllerRegister) Put(pattern string, f HandleFunc) { + p.AddMethod("put", pattern, f) +} + +// Delete add delete method +// usage: +// +// Delete("/api/:id", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func (p *ControllerRegister) Delete(pattern string, f HandleFunc) { + p.AddMethod("delete", pattern, f) +} + +// Head add head method +// usage: +// +// Head("/api/:id", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func (p *ControllerRegister) Head(pattern string, f HandleFunc) { + p.AddMethod("head", pattern, f) +} + +// Patch add patch method +// usage: +// +// Patch("/api/:id", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func (p *ControllerRegister) Patch(pattern string, f HandleFunc) { + p.AddMethod("patch", pattern, f) +} + +// Options add options method +// usage: +// +// Options("/api/:id", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func (p *ControllerRegister) Options(pattern string, f HandleFunc) { + p.AddMethod("options", pattern, f) +} + +// Any add all method +// usage: +// +// Any("/api/:id", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func (p *ControllerRegister) Any(pattern string, f HandleFunc) { + p.AddMethod("*", pattern, f) +} + +// AddMethod add http method router +// usage: +// +// AddMethod("get","/api/:id", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func (p *ControllerRegister) AddMethod(method, pattern string, f HandleFunc) { + method = p.getUpperMethodString(method) + + route := p.createRestfulRouter(f, pattern) + methods := p.getHttpMethodMapMethod(method, "") + route.methods = methods + + p.addRouterForMethod(route) +} + +// Handler add user defined Handler +func (p *ControllerRegister) Handler(pattern string, h http.Handler, options ...interface{}) { + route := p.createHandlerRouter(h, pattern) + if len(options) > 0 { + if _, ok := options[0].(bool); ok { + pattern = path.Join(pattern, "?:all(.*)") + } + } + for m := range HTTPMETHOD { + p.addToRouter(m, pattern, route) + } +} + +// AddAuto router to ControllerRegister. +// example beego.AddAuto(&MainController{}), +// MainController has method List and Page. +// visit the url /main/list to execute List function +// /main/page to execute Page function. +func (p *ControllerRegister) AddAuto(c ControllerInterface) { + p.AddAutoPrefix("/", c) +} + +// AddAutoPrefix Add auto router to ControllerRegister with prefix. +// example beego.AddAutoPrefix("/admin",&MainController{}), +// MainController has method List and Page. +// visit the url /admin/main/list to execute List function +// /admin/main/page to execute Page function. +func (p *ControllerRegister) AddAutoPrefix(prefix string, c ControllerInterface) { + reflectVal := reflect.ValueOf(c) + rt := reflectVal.Type() + ct := reflect.Indirect(reflectVal).Type() + controllerName := strings.TrimSuffix(ct.Name(), p.cfg.ControllerSuffix) + for i := 0; i < rt.NumMethod(); i++ { + methodName := rt.Method(i).Name + if !utils.InSlice(methodName, exceptMethod) { + p.addAutoPrefixMethod(prefix, controllerName, methodName, ct) + } + } +} + +func (p *ControllerRegister) addAutoPrefixMethod(prefix, controllerName, methodName string, ctrl reflect.Type) { + pattern := path.Join(prefix, strings.ToLower(controllerName), strings.ToLower(methodName), "*") + patternInit := path.Join(prefix, controllerName, methodName, "*") + patternFix := path.Join(prefix, strings.ToLower(controllerName), strings.ToLower(methodName)) + patternFixInit := path.Join(prefix, controllerName, methodName) + + route := p.createBeegoRouter(ctrl, pattern) + route.methods = map[string]string{"*": methodName} + for m := range HTTPMETHOD { + + p.addToRouter(m, pattern, route) + + // only case sensitive, we add three more routes + if p.cfg.RouterCaseSensitive { + p.addToRouter(m, patternInit, route) + p.addToRouter(m, patternFix, route) + p.addToRouter(m, patternFixInit, route) + } + } +} + +// InsertFilter Add a FilterFunc with pattern rule and action constant. +// params is for: +// 1. setting the returnOnOutput value (false allows multiple filters to execute) +// 2. determining whether or not params need to be reset. +func (p *ControllerRegister) InsertFilter(pattern string, pos int, filter FilterFunc, opts ...FilterOpt) error { + opts = append(opts, WithCaseSensitive(p.cfg.RouterCaseSensitive)) + mr := newFilterRouter(pattern, filter, opts...) + return p.insertFilterRouter(pos, mr) +} + +// InsertFilterChain is similar to InsertFilter, +// but it will using chainRoot.filterFunc as input to build a new filterFunc +// for example, assume that chainRoot is funcA +// and we add new FilterChain +// +// fc := func(next) { +// return func(ctx) { +// // do something +// next(ctx) +// // do something +// } +// } +func (p *ControllerRegister) InsertFilterChain(pattern string, chain FilterChain, opts ...FilterOpt) { + opts = append([]FilterOpt{WithCaseSensitive(p.cfg.RouterCaseSensitive)}, opts...) + p.filterChains = append(p.filterChains, filterChainConfig{ + pattern: pattern, + chain: chain, + opts: opts, + }) +} + +// add Filter into +func (p *ControllerRegister) insertFilterRouter(pos int, mr *FilterRouter) (err error) { + if pos < BeforeStatic || pos > FinishRouter { + return errors.New("can not find your filter position") + } + p.enableFilter = true + p.filters[pos] = append(p.filters[pos], mr) + return nil +} + +// URLFor does another controller handler in this request function. +// it can access any controller method. +func (p *ControllerRegister) URLFor(endpoint string, values ...interface{}) string { + paths := strings.Split(endpoint, ".") + if len(paths) <= 1 { + logs.Warn("urlfor endpoint must like path.controller.method") + return "" + } + if len(values)%2 != 0 { + logs.Warn("urlfor params must key-value pair") + return "" + } + params := make(map[string]string) + if len(values) > 0 { + key := "" + for k, v := range values { + if k%2 == 0 { + key = fmt.Sprint(v) + } else { + params[key] = fmt.Sprint(v) + } + } + } + controllerName := strings.Join(paths[:len(paths)-1], "/") + methodName := paths[len(paths)-1] + for m, t := range p.routers { + ok, url := p.getURL(t, "/", controllerName, methodName, params, m) + if ok { + return url + } + } + return "" +} + +func (p *ControllerRegister) getURL(t *Tree, url, controllerName, methodName string, params map[string]string, httpMethod string) (bool, string) { + for _, subtree := range t.fixrouters { + u := path.Join(url, subtree.prefix) + ok, u := p.getURL(subtree, u, controllerName, methodName, params, httpMethod) + if ok { + return ok, u + } + } + if t.wildcard != nil { + u := path.Join(url, urlPlaceholder) + ok, u := p.getURL(t.wildcard, u, controllerName, methodName, params, httpMethod) + if ok { + return ok, u + } + } + for _, l := range t.leaves { + if c, ok := l.runObject.(*ControllerInfo); ok { + if c.routerType == routerTypeBeego && + strings.HasSuffix(path.Join(c.controllerType.PkgPath(), c.controllerType.Name()), `/`+controllerName) { + find := false + if HTTPMETHOD[strings.ToUpper(methodName)] { + if len(c.methods) == 0 { + find = true + } else if m, ok := c.methods[strings.ToUpper(methodName)]; ok && m == strings.ToUpper(methodName) { + find = true + } else if m, ok = c.methods["*"]; ok && m == methodName { + find = true + } + } + if !find { + for m, md := range c.methods { + if (m == "*" || m == httpMethod) && md == methodName { + find = true + } + } + } + if find { + if l.regexps == nil { + if len(l.wildcards) == 0 { + return true, strings.Replace(url, "/"+urlPlaceholder, "", 1) + toURL(params) + } + if len(l.wildcards) == 1 { + if v, ok := params[l.wildcards[0]]; ok { + delete(params, l.wildcards[0]) + return true, strings.Replace(url, urlPlaceholder, v, 1) + toURL(params) + } + return false, "" + } + if len(l.wildcards) == 3 && l.wildcards[0] == "." { + if p, ok := params[":path"]; ok { + if e, isok := params[":ext"]; isok { + delete(params, ":path") + delete(params, ":ext") + return true, strings.Replace(url, urlPlaceholder, p+"."+e, -1) + toURL(params) + } + } + } + canSkip := false + for _, v := range l.wildcards { + if v == ":" { + canSkip = true + continue + } + if u, ok := params[v]; ok { + delete(params, v) + url = strings.Replace(url, urlPlaceholder, u, 1) + } else { + if canSkip { + canSkip = false + continue + } + return false, "" + } + } + return true, url + toURL(params) + } + var i int + var startReg bool + regURL := "" + for _, v := range strings.Trim(l.regexps.String(), "^$") { + if v == '(' { + startReg = true + continue + } else if v == ')' { + startReg = false + if v, ok := params[l.wildcards[i]]; ok { + delete(params, l.wildcards[i]) + regURL = regURL + v + i++ + } else { + break + } + } else if !startReg { + regURL = string(append([]rune(regURL), v)) + } + } + if l.regexps.MatchString(regURL) { + ps := strings.Split(regURL, "/") + for _, p := range ps { + url = strings.Replace(url, urlPlaceholder, p, 1) + } + return true, url + toURL(params) + } + } + } + } + } + + return false, "" +} + +func (p *ControllerRegister) execFilter(context *beecontext.Context, urlPath string, pos int) (started bool) { + var preFilterParams map[string]string + for _, filterR := range p.filters[pos] { + b, done := filterR.filter(context, urlPath, preFilterParams) + if done { + return b + } + } + return false +} + +// Implement http.Handler interface. +func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request) { + ctx := p.GetContext() + + ctx.Reset(rw, r) + defer p.GiveBackContext(ctx) + + var preFilterParams map[string]string + p.chainRoot.filter(ctx, p.getUrlPath(ctx), preFilterParams) +} + +func (p *ControllerRegister) serveHttp(ctx *beecontext.Context) { + var err error + startTime := time.Now() + r := ctx.Request + rw := ctx.ResponseWriter.ResponseWriter + var ( + runRouter reflect.Type + findRouter bool + runMethod string + methodParams []*param.MethodParam + routerInfo *ControllerInfo + isRunnable bool + currentSessionOn bool + originRouterInfo *ControllerInfo + originFindRouter bool + ) + + if p.cfg.RecoverFunc != nil { + defer p.cfg.RecoverFunc(ctx, p.cfg) + } + + ctx.Output.EnableGzip = p.cfg.EnableGzip + + if p.cfg.RunMode == DEV { + ctx.Output.Header("Server", p.cfg.ServerName) + } + + urlPath := p.getUrlPath(ctx) + + // filter wrong http method + if !HTTPMETHOD[r.Method] { + exception("405", ctx) + goto Admin + } + + // filter for static file + if len(p.filters[BeforeStatic]) > 0 && p.execFilter(ctx, urlPath, BeforeStatic) { + goto Admin + } + + serverStaticRouter(ctx) + + if ctx.ResponseWriter.Started { + findRouter = true + goto Admin + } + + if r.Method != http.MethodGet && r.Method != http.MethodHead { + body := ctx.Input.Context.Request.Body + if body == nil { + body = io.NopCloser(bytes.NewReader([]byte{})) + } + + if ctx.Input.IsUpload() { + ctx.Input.Context.Request.Body = http.MaxBytesReader(ctx.Input.Context.ResponseWriter, + body, + p.cfg.MaxUploadSize) + } else if p.cfg.CopyRequestBody { + // connection will close if the incoming data are larger (RFC 7231, 6.5.11) + if r.ContentLength > p.cfg.MaxMemory { + logs.Error(errors.New("payload too large")) + exception("413", ctx) + goto Admin + } + ctx.Input.CopyBody(p.cfg.MaxMemory) + } else { + ctx.Input.Context.Request.Body = http.MaxBytesReader(ctx.Input.Context.ResponseWriter, + body, + p.cfg.MaxMemory) + } + + err = ctx.Input.ParseFormOrMultiForm(p.cfg.MaxMemory) + if err != nil { + logs.Error(err) + if strings.Contains(err.Error(), `http: request body too large`) { + exception("413", ctx) + } else { + exception("500", ctx) + } + goto Admin + } + } + + // session init + currentSessionOn = p.cfg.WebConfig.Session.SessionOn + originRouterInfo, originFindRouter = p.FindRouter(ctx) + if originFindRouter { + currentSessionOn = originRouterInfo.sessionOn + } + if currentSessionOn { + ctx.Input.CruSession, err = GlobalSessions.SessionStart(rw, r) + if err != nil { + logs.Error(err) + exception("503", ctx) + goto Admin + } + defer func() { + if ctx.Input.CruSession != nil { + ctx.Input.CruSession.SessionRelease(context.Background(), rw) + } + }() + } + if len(p.filters[BeforeRouter]) > 0 && p.execFilter(ctx, urlPath, BeforeRouter) { + goto Admin + } + // User can define RunController and RunMethod in filter + if ctx.Input.RunController != nil && ctx.Input.RunMethod != "" { + findRouter = true + runMethod = ctx.Input.RunMethod + runRouter = ctx.Input.RunController + } else { + routerInfo, findRouter = p.FindRouter(ctx) + } + + // if no matches to url, throw a not found exception + if !findRouter { + exception("404", ctx) + goto Admin + } + if splat := ctx.Input.Param(":splat"); splat != "" { + for k, v := range strings.Split(splat, "/") { + ctx.Input.SetParam(strconv.Itoa(k), v) + } + } + + if routerInfo != nil { + // store router pattern into context + ctx.Input.SetData("RouterPattern", routerInfo.pattern) + } + + // execute middleware filters + if len(p.filters[BeforeExec]) > 0 && p.execFilter(ctx, urlPath, BeforeExec) { + goto Admin + } + + // check policies + if p.execPolicy(ctx, urlPath) { + goto Admin + } + + if routerInfo != nil { + if routerInfo.routerType == routerTypeRESTFul { + if _, ok := routerInfo.methods[r.Method]; ok { + isRunnable = true + routerInfo.runFunction(ctx) + } else { + exception("405", ctx) + goto Admin + } + } else if routerInfo.routerType == routerTypeHandler { + isRunnable = true + routerInfo.handler.ServeHTTP(ctx.ResponseWriter, ctx.Request) + } else { + runRouter = routerInfo.controllerType + methodParams = routerInfo.methodParams + method := r.Method + if r.Method == http.MethodPost && ctx.Input.Query("_method") == http.MethodPut { + method = http.MethodPut + } + if r.Method == http.MethodPost && ctx.Input.Query("_method") == http.MethodDelete { + method = http.MethodDelete + } + if m, ok := routerInfo.methods[method]; ok { + runMethod = m + } else if m, ok = routerInfo.methods["*"]; ok { + runMethod = m + } else { + runMethod = method + } + } + } + + // also defined runRouter & runMethod from filter + if !isRunnable { + // Invoke the request handler + var execController ControllerInterface + if routerInfo != nil && routerInfo.initialize != nil { + execController = routerInfo.initialize() + } else { + vc := reflect.New(runRouter) + var ok bool + execController, ok = vc.Interface().(ControllerInterface) + if !ok { + panic("controller is not ControllerInterface") + } + } + + // call the controller init function + execController.Init(ctx, runRouter.Name(), runMethod, execController) + + // call prepare function + execController.Prepare() + + // if XSRF is Enable then check cookie where there has any cookie in the request's cookie _csrf + if p.cfg.WebConfig.EnableXSRF { + execController.XSRFToken() + if r.Method == http.MethodPost || r.Method == http.MethodDelete || r.Method == http.MethodPut || + (r.Method == http.MethodPost && (ctx.Input.Query("_method") == http.MethodDelete || ctx.Input.Query("_method") == http.MethodPut)) { + execController.CheckXSRFCookie() + } + } + + execController.URLMapping() + + if !ctx.ResponseWriter.Started { + // exec main logic + switch runMethod { + case http.MethodGet: + execController.Get() + case http.MethodPost: + execController.Post() + case http.MethodDelete: + execController.Delete() + case http.MethodPut: + execController.Put() + case http.MethodHead: + execController.Head() + case http.MethodPatch: + execController.Patch() + case http.MethodOptions: + execController.Options() + case http.MethodTrace: + execController.Trace() + default: + if !execController.HandlerFunc(runMethod) { + vc := reflect.ValueOf(execController) + method := vc.MethodByName(runMethod) + in := param.ConvertParams(methodParams, method.Type(), ctx) + out := method.Call(in) + + // For backward compatibility we only handle response if we had incoming methodParams + if methodParams != nil { + p.handleParamResponse(ctx, execController, out) + } + } + } + + // render template + if !ctx.ResponseWriter.Started && ctx.Output.Status == 0 { + if p.cfg.WebConfig.AutoRender { + if err := execController.Render(); err != nil { + logs.Error(err) + } + } + } + } + + // finish all runRouter. release resource + execController.Finish() + } + + // execute middleware filters + if len(p.filters[AfterExec]) > 0 && p.execFilter(ctx, urlPath, AfterExec) { + goto Admin + } + + if len(p.filters[FinishRouter]) > 0 && p.execFilter(ctx, urlPath, FinishRouter) { + goto Admin + } + +Admin: + // admin module record QPS + + statusCode := ctx.ResponseWriter.Status + if statusCode == 0 { + statusCode = 200 + } + + LogAccess(ctx, &startTime, statusCode) + + timeDur := time.Since(startTime) + ctx.ResponseWriter.Elapsed = timeDur + if p.cfg.Listen.EnableAdmin { + pattern := "" + if routerInfo != nil { + pattern = routerInfo.pattern + } + + if FilterMonitorFunc(r.Method, r.URL.Path, timeDur, pattern, statusCode) { + routerName := "" + if runRouter != nil { + routerName = runRouter.Name() + } + go StatisticsMap.AddStatistics(r.Method, r.URL.Path, routerName, timeDur) + } + } + + if p.cfg.RunMode == DEV && !p.cfg.Log.AccessLogs { + match := map[bool]string{true: "match", false: "nomatch"} + devInfo := fmt.Sprintf("|%15s|%s %3d %s|%13s|%8s|%s %-7s %s %-3s", + ctx.Input.IP(), + logs.ColorByStatus(statusCode), statusCode, logs.ResetColor(), + timeDur.String(), + match[findRouter], + logs.ColorByMethod(r.Method), r.Method, logs.ResetColor(), + r.URL.Path) + if routerInfo != nil { + devInfo += fmt.Sprintf(" r:%s", routerInfo.pattern) + } + + logs.Debug(devInfo) + } + // Call WriteHeader if status code has been set changed + if ctx.Output.Status != 0 { + ctx.ResponseWriter.WriteHeader(ctx.Output.Status) + } +} + +func (p *ControllerRegister) getUrlPath(ctx *beecontext.Context) string { + urlPath := ctx.Request.URL.Path + if !p.cfg.RouterCaseSensitive { + urlPath = strings.ToLower(urlPath) + } + return urlPath +} + +func (p *ControllerRegister) handleParamResponse(context *beecontext.Context, execController ControllerInterface, results []reflect.Value) { + // looping in reverse order for the case when both error and value are returned and error sets the response status code + for i := len(results) - 1; i >= 0; i-- { + result := results[i] + if result.Kind() != reflect.Interface || !result.IsNil() { + resultValue := result.Interface() + context.RenderMethodResult(resultValue) + } + } + if !context.ResponseWriter.Started && len(results) > 0 && context.Output.Status == 0 { + context.Output.SetStatus(200) + } +} + +// FindRouter Find Router info for URL +func (p *ControllerRegister) FindRouter(context *beecontext.Context) (routerInfo *ControllerInfo, isFind bool) { + urlPath := context.Input.URL() + if !p.cfg.RouterCaseSensitive { + urlPath = strings.ToLower(urlPath) + } + httpMethod := context.Input.Method() + if t, ok := p.routers[httpMethod]; ok { + runObject := t.Match(urlPath, context) + if r, ok := runObject.(*ControllerInfo); ok { + return r, true + } + } + return +} + +// GetAllControllerInfo get all ControllerInfo +func (p *ControllerRegister) GetAllControllerInfo() (routerInfos []*ControllerInfo) { + for _, webTree := range p.routers { + composeControllerInfos(webTree, &routerInfos) + } + return +} + +func composeControllerInfos(tree *Tree, routerInfos *[]*ControllerInfo) { + for _, subTree := range tree.fixrouters { + composeControllerInfos(subTree, routerInfos) + } + if tree.wildcard != nil { + composeControllerInfos(tree.wildcard, routerInfos) + } + for _, l := range tree.leaves { + if c, ok := l.runObject.(*ControllerInfo); ok { + *routerInfos = append(*routerInfos, c) + } + } +} + +func toURL(params map[string]string) string { + if len(params) == 0 { + return "" + } + u := "?" + for k, v := range params { + u += k + "=" + v + "&" + } + return strings.TrimRight(u, "&") +} + +// LogAccess logging info HTTP Access +func LogAccess(ctx *beecontext.Context, startTime *time.Time, statusCode int) { + BeeApp.LogAccess(ctx, startTime, statusCode) +} diff --git a/router_test.go b/server/web/router_test.go similarity index 53% rename from router_test.go rename to server/web/router_test.go index 2797b33a0b..13f3e4785c 100644 --- a/router_test.go +++ b/server/web/router_test.go @@ -12,18 +12,48 @@ // See the License for the specific language governing permissions and // limitations under the License. -package beego +package web import ( + "bytes" + "fmt" "net/http" "net/http/httptest" + "reflect" + "sort" "strings" "testing" - "github.com/astaxie/beego/context" - "github.com/astaxie/beego/logs" + "github.com/beego/beego/v2/core/logs" + "github.com/beego/beego/v2/server/web/context" ) +type PrefixTestController struct { + Controller +} + +func (ptc *PrefixTestController) PrefixList() { + ptc.Ctx.Output.Body([]byte("i am list in prefix test")) +} + +type TestControllerWithInterface struct{} + +func (m TestControllerWithInterface) Ping() { + fmt.Println("pong") +} + +func (m *TestControllerWithInterface) PingPointer() { + fmt.Println("pong pointer") +} + +type TestDefineController struct { + Controller +} + +func (tc *TestDefineController) List() { + tc.Ctx.Output.Body([]byte("i am list")) +} + type TestController struct { Controller } @@ -71,7 +101,6 @@ func (tc *TestController) GetEmptyBody() { tc.Ctx.Output.Body(res) } - type JSONController struct { Controller } @@ -86,10 +115,24 @@ func (jc *JSONController) Get() { jc.Ctx.Output.Body([]byte("ok")) } +func TestPrefixUrlFor(t *testing.T) { + handler := NewControllerRegister() + handler.Add("/my/prefix/list", &PrefixTestController{}, WithRouterMethods(&PrefixTestController{}, "get:PrefixList")) + + if a := handler.URLFor(`PrefixTestController.PrefixList`); a != `/my/prefix/list` { + logs.Info(a) + t.Errorf("PrefixTestController.PrefixList must equal to /my/prefix/list") + } + if a := handler.URLFor(`TestController.PrefixList`); a != `` { + logs.Info(a) + t.Errorf("TestController.PrefixList must equal to empty string") + } +} + func TestUrlFor(t *testing.T) { handler := NewControllerRegister() - handler.Add("/api/list", &TestController{}, "*:List") - handler.Add("/person/:last/:first", &TestController{}, "*:Param") + handler.Add("/api/list", &TestController{}, WithRouterMethods(&TestController{}, "*:List")) + handler.Add("/person/:last/:first", &TestController{}, WithRouterMethods(&TestController{}, "*:Param")) if a := handler.URLFor("TestController.List"); a != "/api/list" { logs.Info(a) t.Errorf("TestController.List must equal to /api/list") @@ -102,19 +145,21 @@ func TestUrlFor(t *testing.T) { func TestUrlFor3(t *testing.T) { handler := NewControllerRegister() handler.AddAuto(&TestController{}) - if a := handler.URLFor("TestController.Myext"); a != "/test/myext" && a != "/Test/Myext" { + a := handler.URLFor("TestController.Myext") + if a != "/test/myext" && a != "/Test/Myext" { t.Errorf("TestController.Myext must equal to /test/myext, but get " + a) } - if a := handler.URLFor("TestController.GetURL"); a != "/test/geturl" && a != "/Test/GetURL" { + a = handler.URLFor("TestController.GetURL") + if a != "/test/geturl" && a != "/Test/GetURL" { t.Errorf("TestController.GetURL must equal to /test/geturl, but get " + a) } } func TestUrlFor2(t *testing.T) { handler := NewControllerRegister() - handler.Add("/v1/:v/cms_:id(.+)_:page(.+).html", &TestController{}, "*:List") - handler.Add("/v1/:username/edit", &TestController{}, "get:GetURL") - handler.Add("/v1/:v(.+)_cms/ttt_:id(.+)_:page(.+).html", &TestController{}, "*:Param") + handler.Add("/v1/:v/cms_:id(.+)_:page(.+).html", &TestController{}, WithRouterMethods(&TestController{}, "*:List")) + handler.Add("/v1/:username/edit", &TestController{}, WithRouterMethods(&TestController{}, "get:GetURL")) + handler.Add("/v1/:v(.+)_cms/ttt_:id(.+)_:page(.+).html", &TestController{}, WithRouterMethods(&TestController{}, "*:Param")) handler.Add("/:year:int/:month:int/:title/:entid", &TestController{}) if handler.URLFor("TestController.GetURL", ":username", "astaxie") != "/v1/astaxie/edit" { logs.Info(handler.URLFor("TestController.GetURL")) @@ -144,7 +189,7 @@ func TestUserFunc(t *testing.T) { w := httptest.NewRecorder() handler := NewControllerRegister() - handler.Add("/api/list", &TestController{}, "*:List") + handler.Add("/api/list", &TestController{}, WithRouterMethods(&TestController{}, "*:List")) handler.ServeHTTP(w, r) if w.Body.String() != "i am list" { t.Errorf("user define func can't run") @@ -211,13 +256,28 @@ func TestAutoExtFunc(t *testing.T) { } } -func TestRouteOk(t *testing.T) { +func TestEscape(t *testing.T) { + r, _ := http.NewRequest("GET", "/search/%E4%BD%A0%E5%A5%BD", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.Get("/search/:keyword(.+)", func(ctx *context.Context) { + value := ctx.Input.Param(":keyword") + ctx.Output.Body([]byte(value)) + }) + handler.ServeHTTP(w, r) + str := w.Body.String() + if str != "你好" { + t.Errorf("incorrect, %s", str) + } +} +func TestRouteOk(t *testing.T) { r, _ := http.NewRequest("GET", "/person/anderson/thomas?learn=kungfu", nil) w := httptest.NewRecorder() handler := NewControllerRegister() - handler.Add("/person/:last/:first", &TestController{}, "get:GetParams") + handler.Add("/person/:last/:first", &TestController{}, WithRouterMethods(&TestController{}, "get:GetParams")) handler.ServeHTTP(w, r) body := w.Body.String() if body != "anderson+thomas+kungfu" { @@ -226,12 +286,11 @@ func TestRouteOk(t *testing.T) { } func TestManyRoute(t *testing.T) { - r, _ := http.NewRequest("GET", "/beego32-12.html", nil) w := httptest.NewRecorder() handler := NewControllerRegister() - handler.Add("/beego:id([0-9]+)-:page([0-9]+).html", &TestController{}, "get:GetManyRouter") + handler.Add("/beego:id([0-9]+)-:page([0-9]+).html", &TestController{}, WithRouterMethods(&TestController{}, "get:GetManyRouter")) handler.ServeHTTP(w, r) body := w.Body.String() @@ -243,12 +302,11 @@ func TestManyRoute(t *testing.T) { // Test for issue #1669 func TestEmptyResponse(t *testing.T) { - r, _ := http.NewRequest("GET", "/beego-empty.html", nil) w := httptest.NewRecorder() handler := NewControllerRegister() - handler.Add("/beego-empty.html", &TestController{}, "get:GetEmptyBody") + handler.Add("/beego-empty.html", &TestController{}, WithRouterMethods(&TestController{}, "get:GetEmptyBody")) handler.ServeHTTP(w, r) if body := w.Body.String(); body != "" { @@ -306,7 +364,22 @@ func TestAutoPrefix(t *testing.T) { } } -func TestRouterGet(t *testing.T) { +func TestAutoPrefixWithDefinedControllerSuffix(t *testing.T) { + r, _ := http.NewRequest("GET", "/admin/test/list", nil) + w := httptest.NewRecorder() + + config := BeeApp.Cfg + config.ControllerSuffix = "DefineController" + + handler := NewControllerRegisterWithCfg(config) + handler.AddAutoPrefix("/admin", &TestDefineController{}) + handler.ServeHTTP(w, r) + if w.Body.String() != "i am list" { + t.Errorf("TestAutoPrefixWithDefinedControllerSuffix can't run") + } +} + +func TestCtrlGet(t *testing.T) { r, _ := http.NewRequest("GET", "/user", nil) w := httptest.NewRecorder() @@ -316,11 +389,11 @@ func TestRouterGet(t *testing.T) { }) handler.ServeHTTP(w, r) if w.Body.String() != "Get userlist" { - t.Errorf("TestRouterGet can't run") + t.Errorf("TestCtrlGet can't run") } } -func TestRouterPost(t *testing.T) { +func TestCtrlPost(t *testing.T) { r, _ := http.NewRequest("POST", "/user/123", nil) w := httptest.NewRecorder() @@ -330,7 +403,7 @@ func TestRouterPost(t *testing.T) { }) handler.ServeHTTP(w, r) if w.Body.String() != "123" { - t.Errorf("TestRouterPost can't run") + t.Errorf("TestCtrlPost can't run") } } @@ -363,7 +436,7 @@ func TestRouterHandlerAll(t *testing.T) { } // -// Benchmarks NewApp: +// Benchmarks NewHttpSever: // func beegoFilterFunc(ctx *context.Context) { @@ -422,7 +495,7 @@ func TestInsertFilter(t *testing.T) { testName := "TestInsertFilter" mux := NewControllerRegister() - mux.InsertFilter("*", BeforeRouter, func(*context.Context) {}) + mux.InsertFilter("*", BeforeRouter, func(*context.Context) {}, WithReturnOnOutput(true)) if !mux.filters[BeforeRouter][0].returnOnOutput { t.Errorf( "%s: passing no variadic params should set returnOnOutput to true", @@ -435,7 +508,7 @@ func TestInsertFilter(t *testing.T) { } mux = NewControllerRegister() - mux.InsertFilter("*", BeforeRouter, func(*context.Context) {}, false) + mux.InsertFilter("*", BeforeRouter, func(*context.Context) {}, WithReturnOnOutput(false)) if mux.filters[BeforeRouter][0].returnOnOutput { t.Errorf( "%s: passing false as 1st variadic param should set returnOnOutput to false", @@ -443,7 +516,7 @@ func TestInsertFilter(t *testing.T) { } mux = NewControllerRegister() - mux.InsertFilter("*", BeforeRouter, func(*context.Context) {}, true, true) + mux.InsertFilter("*", BeforeRouter, func(*context.Context) {}, WithReturnOnOutput(true), WithResetParams(true)) if !mux.filters[BeforeRouter][0].resetParams { t.Errorf( "%s: passing true as 2nd variadic param should set resetParams to true", @@ -460,7 +533,7 @@ func TestParamResetFilter(t *testing.T) { mux := NewControllerRegister() - mux.InsertFilter("*", BeforeExec, beegoResetParams, true, true) + mux.InsertFilter("*", BeforeExec, beegoResetParams, WithReturnOnOutput(true), WithResetParams(true)) mux.Get(route, beegoHandleResetParams) @@ -468,7 +541,7 @@ func TestParamResetFilter(t *testing.T) { mux.ServeHTTP(rw, r) // The two functions, `beegoResetParams` and `beegoHandleResetParams` add - // a response header of `Splat`. The expectation here is that that Header + // a response header of `Splat`. The expectation here is that Header // value should match what the _request's_ router set, not the filter's. headers := rw.Result().Header @@ -513,8 +586,8 @@ func TestFilterBeforeExec(t *testing.T) { url := "/beforeExec" mux := NewControllerRegister() - mux.InsertFilter(url, BeforeRouter, beegoFilterNoOutput) - mux.InsertFilter(url, BeforeExec, beegoBeforeExec1) + mux.InsertFilter(url, BeforeRouter, beegoFilterNoOutput, WithReturnOnOutput(true)) + mux.InsertFilter(url, BeforeExec, beegoBeforeExec1, WithReturnOnOutput(true)) mux.Get(url, beegoFilterFunc) @@ -541,7 +614,7 @@ func TestFilterAfterExec(t *testing.T) { mux := NewControllerRegister() mux.InsertFilter(url, BeforeRouter, beegoFilterNoOutput) mux.InsertFilter(url, BeforeExec, beegoFilterNoOutput) - mux.InsertFilter(url, AfterExec, beegoAfterExec1, false) + mux.InsertFilter(url, AfterExec, beegoAfterExec1, WithReturnOnOutput(false)) mux.Get(url, beegoFilterFunc) @@ -569,10 +642,10 @@ func TestFilterFinishRouter(t *testing.T) { url := "/finishRouter" mux := NewControllerRegister() - mux.InsertFilter(url, BeforeRouter, beegoFilterNoOutput) - mux.InsertFilter(url, BeforeExec, beegoFilterNoOutput) - mux.InsertFilter(url, AfterExec, beegoFilterNoOutput) - mux.InsertFilter(url, FinishRouter, beegoFinishRouter1) + mux.InsertFilter(url, BeforeRouter, beegoFilterNoOutput, WithReturnOnOutput(true)) + mux.InsertFilter(url, BeforeExec, beegoFilterNoOutput, WithReturnOnOutput(true)) + mux.InsertFilter(url, AfterExec, beegoFilterNoOutput, WithReturnOnOutput(true)) + mux.InsertFilter(url, FinishRouter, beegoFinishRouter1, WithReturnOnOutput(true)) mux.Get(url, beegoFilterFunc) @@ -603,7 +676,7 @@ func TestFilterFinishRouterMultiFirstOnly(t *testing.T) { url := "/finishRouterMultiFirstOnly" mux := NewControllerRegister() - mux.InsertFilter(url, FinishRouter, beegoFinishRouter1, false) + mux.InsertFilter(url, FinishRouter, beegoFinishRouter1, WithReturnOnOutput(false)) mux.InsertFilter(url, FinishRouter, beegoFinishRouter2) mux.Get(url, beegoFilterFunc) @@ -630,8 +703,8 @@ func TestFilterFinishRouterMulti(t *testing.T) { url := "/finishRouterMulti" mux := NewControllerRegister() - mux.InsertFilter(url, FinishRouter, beegoFinishRouter1, false) - mux.InsertFilter(url, FinishRouter, beegoFinishRouter2, false) + mux.InsertFilter(url, FinishRouter, beegoFinishRouter1, WithReturnOnOutput(false)) + mux.InsertFilter(url, FinishRouter, beegoFinishRouter2, WithReturnOnOutput(false)) mux.Get(url, beegoFilterFunc) @@ -656,17 +729,14 @@ func beegoBeforeRouter1(ctx *context.Context) { ctx.WriteString("|BeforeRouter1") } - func beegoBeforeExec1(ctx *context.Context) { ctx.WriteString("|BeforeExec1") } - func beegoAfterExec1(ctx *context.Context) { ctx.WriteString("|AfterExec1") } - func beegoFinishRouter1(ctx *context.Context) { ctx.WriteString("|FinishRouter1") } @@ -709,3 +779,375 @@ func TestYAMLPrepare(t *testing.T) { t.Errorf(w.Body.String()) } } + +func TestRouterEntityTooLargeCopyBody(t *testing.T) { + _MaxMemory := BConfig.MaxMemory + _CopyRequestBody := BConfig.CopyRequestBody + BConfig.CopyRequestBody = true + BConfig.MaxMemory = 20 + + BeeApp.Cfg.CopyRequestBody = true + BeeApp.Cfg.MaxMemory = 20 + b := bytes.NewBuffer([]byte("barbarbarbarbarbarbarbarbarbar")) + r, _ := http.NewRequest("POST", "/user/123", b) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.Post("/user/:id", func(ctx *context.Context) { + ctx.Output.Body([]byte(ctx.Input.Param(":id"))) + }) + handler.ServeHTTP(w, r) + + BConfig.CopyRequestBody = _CopyRequestBody + BConfig.MaxMemory = _MaxMemory + + if w.Code != http.StatusRequestEntityTooLarge { + t.Errorf("TestRouterRequestEntityTooLarge can't run") + } +} + +func TestRouterSessionSet(t *testing.T) { + oldGlobalSessionOn := BConfig.WebConfig.Session.SessionOn + defer func() { + BConfig.WebConfig.Session.SessionOn = oldGlobalSessionOn + }() + + // global sessionOn = false, router sessionOn = false + r, _ := http.NewRequest("GET", "/user", nil) + w := httptest.NewRecorder() + handler := NewControllerRegister() + handler.Add("/user", &TestController{}, WithRouterMethods(&TestController{}, "get:Get"), + WithRouterSessionOn(false)) + handler.ServeHTTP(w, r) + if w.Header().Get("Set-Cookie") != "" { + t.Errorf("TestRotuerSessionSet failed") + } + + // global sessionOn = false, router sessionOn = true + r, _ = http.NewRequest("GET", "/user", nil) + w = httptest.NewRecorder() + handler = NewControllerRegister() + handler.Add("/user", &TestController{}, WithRouterMethods(&TestController{}, "get:Get"), + WithRouterSessionOn(true)) + handler.ServeHTTP(w, r) + if w.Header().Get("Set-Cookie") != "" { + t.Errorf("TestRotuerSessionSet failed") + } + + BConfig.WebConfig.Session.SessionOn = true + if err := registerSession(); err != nil { + t.Errorf("register session failed, error: %s", err.Error()) + } + // global sessionOn = true, router sessionOn = false + r, _ = http.NewRequest("GET", "/user", nil) + w = httptest.NewRecorder() + handler = NewControllerRegister() + handler.Add("/user", &TestController{}, WithRouterMethods(&TestController{}, "get:Get"), + WithRouterSessionOn(false)) + handler.ServeHTTP(w, r) + if w.Header().Get("Set-Cookie") != "" { + t.Errorf("TestRotuerSessionSet failed") + } + + // global sessionOn = true, router sessionOn = true + r, _ = http.NewRequest("GET", "/user", nil) + w = httptest.NewRecorder() + handler = NewControllerRegister() + handler.Add("/user", &TestController{}, WithRouterMethods(&TestController{}, "get:Get"), + WithRouterSessionOn(true)) + handler.ServeHTTP(w, r) + if w.Header().Get("Set-Cookie") == "" { + t.Errorf("TestRotuerSessionSet failed") + } +} + +func TestRouterCtrlGet(t *testing.T) { + r, _ := http.NewRequest(http.MethodGet, "/user", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.CtrlGet("/user", ExampleController.Ping) + handler.ServeHTTP(w, r) + if w.Body.String() != exampleBody { + t.Errorf("TestRouterCtrlGet can't run") + } +} + +func TestRouterCtrlPost(t *testing.T) { + r, _ := http.NewRequest(http.MethodPost, "/user", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.CtrlPost("/user", ExampleController.Ping) + handler.ServeHTTP(w, r) + if w.Body.String() != exampleBody { + t.Errorf("TestRouterCtrlPost can't run") + } +} + +func TestRouterCtrlHead(t *testing.T) { + r, _ := http.NewRequest(http.MethodHead, "/user", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.CtrlHead("/user", ExampleController.Ping) + handler.ServeHTTP(w, r) + if w.Body.String() != exampleBody { + t.Errorf("TestRouterCtrlHead can't run") + } +} + +func TestRouterCtrlPut(t *testing.T) { + r, _ := http.NewRequest(http.MethodPut, "/user", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.CtrlPut("/user", ExampleController.Ping) + handler.ServeHTTP(w, r) + if w.Body.String() != exampleBody { + t.Errorf("TestRouterCtrlPut can't run") + } +} + +func TestRouterCtrlPatch(t *testing.T) { + r, _ := http.NewRequest(http.MethodPatch, "/user", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.CtrlPatch("/user", ExampleController.Ping) + handler.ServeHTTP(w, r) + if w.Body.String() != exampleBody { + t.Errorf("TestRouterCtrlPatch can't run") + } +} + +func TestRouterCtrlDelete(t *testing.T) { + r, _ := http.NewRequest(http.MethodDelete, "/user", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.CtrlDelete("/user", ExampleController.Ping) + handler.ServeHTTP(w, r) + if w.Body.String() != exampleBody { + t.Errorf("TestRouterCtrlDelete can't run") + } +} + +func TestRouterCtrlAny(t *testing.T) { + handler := NewControllerRegister() + handler.CtrlAny("/user", ExampleController.Ping) + + for method := range HTTPMETHOD { + w := httptest.NewRecorder() + r, _ := http.NewRequest(method, "/user", nil) + handler.ServeHTTP(w, r) + if w.Body.String() != exampleBody { + t.Errorf("TestRouterCtrlAny can't run, get the response is " + w.Body.String()) + } + } +} + +func TestRouterCtrlGetPointerMethod(t *testing.T) { + r, _ := http.NewRequest(http.MethodGet, "/user", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.CtrlGet("/user", (*ExampleController).PingPointer) + handler.ServeHTTP(w, r) + if w.Body.String() != examplePointerBody { + t.Errorf("TestRouterCtrlGetPointerMethod can't run") + } +} + +func TestRouterCtrlPostPointerMethod(t *testing.T) { + r, _ := http.NewRequest(http.MethodPost, "/user", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.CtrlPost("/user", (*ExampleController).PingPointer) + handler.ServeHTTP(w, r) + if w.Body.String() != examplePointerBody { + t.Errorf("TestRouterCtrlPostPointerMethod can't run") + } +} + +func TestRouterCtrlHeadPointerMethod(t *testing.T) { + r, _ := http.NewRequest(http.MethodHead, "/user", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.CtrlHead("/user", (*ExampleController).PingPointer) + handler.ServeHTTP(w, r) + if w.Body.String() != examplePointerBody { + t.Errorf("TestRouterCtrlHeadPointerMethod can't run") + } +} + +func TestRouterCtrlPutPointerMethod(t *testing.T) { + r, _ := http.NewRequest(http.MethodPut, "/user", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.CtrlPut("/user", (*ExampleController).PingPointer) + handler.ServeHTTP(w, r) + if w.Body.String() != examplePointerBody { + t.Errorf("TestRouterCtrlPutPointerMethod can't run") + } +} + +func TestRouterCtrlPatchPointerMethod(t *testing.T) { + r, _ := http.NewRequest(http.MethodPatch, "/user", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.CtrlPatch("/user", (*ExampleController).PingPointer) + handler.ServeHTTP(w, r) + if w.Body.String() != examplePointerBody { + t.Errorf("TestRouterCtrlPatchPointerMethod can't run") + } +} + +func TestRouterCtrlDeletePointerMethod(t *testing.T) { + r, _ := http.NewRequest(http.MethodDelete, "/user", nil) + w := httptest.NewRecorder() + + handler := NewControllerRegister() + handler.CtrlDelete("/user", (*ExampleController).PingPointer) + handler.ServeHTTP(w, r) + if w.Body.String() != examplePointerBody { + t.Errorf("TestRouterCtrlDeletePointerMethod can't run") + } +} + +func TestRouterCtrlAnyPointerMethod(t *testing.T) { + handler := NewControllerRegister() + handler.CtrlAny("/user", (*ExampleController).PingPointer) + + for method := range HTTPMETHOD { + w := httptest.NewRecorder() + r, _ := http.NewRequest(method, "/user", nil) + handler.ServeHTTP(w, r) + if w.Body.String() != examplePointerBody { + t.Errorf("TestRouterCtrlAnyPointerMethod can't run, get the response is " + w.Body.String()) + } + } +} + +func TestRouterAddRouterMethodPanicInvalidMethod(t *testing.T) { + method := "some random method" + message := "not support http method: " + strings.ToUpper(method) + defer func() { + err := recover() + if err != nil { // 产生了panic异常 + errStr, ok := err.(string) + if ok && errStr == message { + return + } + } + t.Errorf(fmt.Sprintf("TestRouterAddRouterMethodPanicInvalidMethod failed: %v", err)) + }() + + handler := NewControllerRegister() + handler.AddRouterMethod(method, "/user", ExampleController.Ping) +} + +func TestRouterAddRouterMethodPanicNotAMethod(t *testing.T) { + method := http.MethodGet + message := "not a method" + defer func() { + err := recover() + if err != nil { // 产生了panic异常 + errStr, ok := err.(string) + if ok && errStr == message { + return + } + } + t.Errorf(fmt.Sprintf("TestRouterAddRouterMethodPanicNotAMethod failed: %v", err)) + }() + + handler := NewControllerRegister() + handler.AddRouterMethod(method, "/user", ExampleController{}) +} + +func TestRouterAddRouterMethodPanicNotPublicMethod(t *testing.T) { + method := http.MethodGet + message := "ping is not a public method" + defer func() { + err := recover() + if err != nil { // 产生了panic异常 + errStr, ok := err.(string) + if ok && errStr == message { + return + } + } + t.Errorf(fmt.Sprintf("TestRouterAddRouterMethodPanicNotPublicMethod failed: %v", err)) + }() + + handler := NewControllerRegister() + handler.AddRouterMethod(method, "/user", ExampleController.ping) +} + +func TestRouterAddRouterMethodPanicNotImplementInterface(t *testing.T) { + method := http.MethodGet + message := "web.TestControllerWithInterface is not implemented ControllerInterface" + defer func() { + err := recover() + if err != nil { // 产生了panic异常 + errStr, ok := err.(string) + if ok && errStr == message { + return + } + } + t.Errorf(fmt.Sprintf("TestRouterAddRouterMethodPanicNotImplementInterface failed: %v", err)) + }() + + handler := NewControllerRegister() + handler.AddRouterMethod(method, "/user", TestControllerWithInterface.Ping) +} + +func TestRouterAddRouterPointerMethodPanicNotImplementInterface(t *testing.T) { + method := http.MethodGet + message := "web.TestControllerWithInterface is not implemented ControllerInterface" + defer func() { + err := recover() + if err != nil { // 产生了panic异常 + errStr, ok := err.(string) + if ok && errStr == message { + return + } + } + t.Errorf(fmt.Sprintf("TestRouterAddRouterPointerMethodPanicNotImplementInterface failed: %v", err)) + }() + + handler := NewControllerRegister() + handler.AddRouterMethod(method, "/user", (*TestControllerWithInterface).PingPointer) +} + +func TestGetAllControllerInfo(t *testing.T) { + handler := NewControllerRegister() + handler.Add("/level1", &TestController{}, WithRouterMethods(&TestController{}, "get:Get")) + handler.Add("/level1/level2", &TestController{}, WithRouterMethods(&TestController{}, "get:Get")) + handler.Add("/:name1", &TestController{}, WithRouterMethods(&TestController{}, "post:Post")) + + var actualPatterns []string + var actualMethods []string + for _, controllerInfo := range handler.GetAllControllerInfo() { + actualPatterns = append(actualPatterns, controllerInfo.GetPattern()) + for _, httpMethod := range controllerInfo.GetMethod() { + actualMethods = append(actualMethods, httpMethod) + } + } + sort.Strings(actualPatterns) + expectedPatterns := []string{"/level1", "/level1/level2", "/:name1"} + sort.Strings(expectedPatterns) + if !reflect.DeepEqual(actualPatterns, expectedPatterns) { + t.Errorf("ControllerInfo.GetMethod expected %#v, but %#v got", expectedPatterns, actualPatterns) + } + + sort.Strings(actualMethods) + expectedMethods := []string{"Get", "Get", "Post"} + sort.Strings(expectedMethods) + if !reflect.DeepEqual(actualMethods, expectedMethods) { + t.Errorf("ControllerInfo.GetMethod expected %#v, but %#v got", expectedMethods, actualMethods) + } +} diff --git a/server/web/server.go b/server/web/server.go new file mode 100644 index 0000000000..8c80b20e16 --- /dev/null +++ b/server/web/server.go @@ -0,0 +1,980 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package web + +import ( + "crypto/tls" + "crypto/x509" + "fmt" + "net" + "net/http" + "net/http/fcgi" + "os" + "path" + "strconv" + "strings" + "text/template" + "time" + + "golang.org/x/crypto/acme/autocert" + + "github.com/beego/beego/v2/core/logs" + "github.com/beego/beego/v2/core/utils" + beecontext "github.com/beego/beego/v2/server/web/context" + "github.com/beego/beego/v2/server/web/grace" +) + +// BeeApp is an application instance +// If you are using single server, you could use this +// But if you need multiple servers, do not use this +var BeeApp *HttpServer + +func init() { + // create beego application + BeeApp = NewHttpSever() +} + +// HttpServer defines beego application with a new PatternServeMux. +type HttpServer struct { + Handlers *ControllerRegister + Server *http.Server + Cfg *Config + LifeCycleCallbacks []LifeCycleCallback +} + +// NewHttpSever returns a new beego application. +// this method will use the BConfig as the configure to create HttpServer +// Be careful that when you update BConfig, the server's Cfg will be updated too +func NewHttpSever() *HttpServer { + return NewHttpServerWithCfg(BConfig) +} + +// NewHttpServerWithCfg will create an sever with specific cfg +func NewHttpServerWithCfg(cfg *Config) *HttpServer { + cr := NewControllerRegisterWithCfg(cfg) + app := &HttpServer{ + Handlers: cr, + Server: &http.Server{}, + Cfg: cfg, + } + + return app +} + +// MiddleWare function for http.Handler +type MiddleWare func(http.Handler) http.Handler + +// LifeCycleCallback configures callback. +// Developer can implement this interface to add custom logic to server lifecycle +type LifeCycleCallback interface { + AfterStart(app *HttpServer) + BeforeShutdown(app *HttpServer) +} + +// Run beego application. +func (app *HttpServer) Run(addr string, mws ...MiddleWare) { + initBeforeHTTPRun() + + // init... + app.initAddr(addr) + app.Handlers.Init() + + addr = app.Cfg.Listen.HTTPAddr + + if app.Cfg.Listen.HTTPPort != 0 { + addr = fmt.Sprintf("%s:%d", app.Cfg.Listen.HTTPAddr, app.Cfg.Listen.HTTPPort) + } + + var ( + err error + l net.Listener + endRunning = make(chan bool, 1) + ) + + // run cgi server + if app.Cfg.Listen.EnableFcgi { + for _, lifeCycleCallback := range app.LifeCycleCallbacks { + lifeCycleCallback.AfterStart(app) + } + if app.Cfg.Listen.EnableStdIo { + if err = fcgi.Serve(nil, app.Handlers); err == nil { // standard I/O + logs.Info("Use FCGI via standard I/O") + } else { + logs.Critical("Cannot use FCGI via standard I/O", err) + } + for _, lifeCycleCallback := range app.LifeCycleCallbacks { + lifeCycleCallback.BeforeShutdown(app) + } + return + } + if app.Cfg.Listen.HTTPPort == 0 { + // remove the Socket file before start + if utils.FileExists(addr) { + os.Remove(addr) + } + l, err = net.Listen("unix", addr) + } else { + l, err = net.Listen("tcp", addr) + } + if err != nil { + logs.Critical("Listen for Fcgi: ", err) + } + if err = fcgi.Serve(l, app.Handlers); err != nil { + logs.Critical("fcgi.Serve: ", err) + } + for _, lifeCycleCallback := range app.LifeCycleCallbacks { + lifeCycleCallback.BeforeShutdown(app) + } + return + } + + app.Server.Handler = app.Handlers + for i := len(mws) - 1; i >= 0; i-- { + if mws[i] == nil { + continue + } + app.Server.Handler = mws[i](app.Server.Handler) + } + app.Server.ReadTimeout = time.Duration(app.Cfg.Listen.ServerTimeOut) * time.Second + app.Server.WriteTimeout = time.Duration(app.Cfg.Listen.ServerTimeOut) * time.Second + app.Server.ErrorLog = logs.GetLogger("HTTP") + + // run graceful mode + if app.Cfg.Listen.Graceful { + var opts []grace.ServerOption + for _, lifeCycleCallback := range app.LifeCycleCallbacks { + lifeCycleCallbackDup := lifeCycleCallback + opts = append(opts, grace.WithShutdownCallback(func() { + lifeCycleCallbackDup.BeforeShutdown(app) + })) + } + + httpsAddr := app.Cfg.Listen.HTTPSAddr + app.Server.Addr = httpsAddr + if app.Cfg.Listen.EnableHTTPS || app.Cfg.Listen.EnableMutualHTTPS { + go func() { + time.Sleep(1000 * time.Microsecond) + if app.Cfg.Listen.HTTPSPort != 0 { + httpsAddr = fmt.Sprintf("%s:%d", app.Cfg.Listen.HTTPSAddr, app.Cfg.Listen.HTTPSPort) + app.Server.Addr = httpsAddr + } + server := grace.NewServer(httpsAddr, app.Server.Handler, opts...) + server.Server.ReadTimeout = app.Server.ReadTimeout + server.Server.WriteTimeout = app.Server.WriteTimeout + var ln net.Listener + if app.Cfg.Listen.EnableMutualHTTPS { + if ln, err = server.ListenMutualTLS(app.Cfg.Listen.HTTPSCertFile, + app.Cfg.Listen.HTTPSKeyFile, + app.Cfg.Listen.TrustCaFile); err != nil { + logs.Critical("ListenMutualTLS: ", err, fmt.Sprintf("%d", os.Getpid())) + return + } + } else { + if app.Cfg.Listen.AutoTLS { + m := autocert.Manager{ + Prompt: autocert.AcceptTOS, + HostPolicy: autocert.HostWhitelist(app.Cfg.Listen.Domains...), + Cache: autocert.DirCache(app.Cfg.Listen.TLSCacheDir), + } + app.Server.TLSConfig = &tls.Config{GetCertificate: m.GetCertificate} + app.Cfg.Listen.HTTPSCertFile, app.Cfg.Listen.HTTPSKeyFile = "", "" + } + if ln, err = server.ListenTLS(app.Cfg.Listen.HTTPSCertFile, app.Cfg.Listen.HTTPSKeyFile); err != nil { + logs.Critical("ListenTLS: ", err, fmt.Sprintf("%d", os.Getpid())) + return + } + } + for _, callback := range app.LifeCycleCallbacks { + callback.AfterStart(app) + } + if err = server.ServeTLS(ln); err != nil { + logs.Critical("ServeTLS: ", err, fmt.Sprintf("%d", os.Getpid())) + time.Sleep(100 * time.Microsecond) + } + endRunning <- true + }() + } + if app.Cfg.Listen.EnableHTTP { + go func() { + server := grace.NewServer(addr, app.Server.Handler, opts...) + server.Server.ReadTimeout = app.Server.ReadTimeout + server.Server.WriteTimeout = app.Server.WriteTimeout + if app.Cfg.Listen.ListenTCP4 { + server.Network = "tcp4" + } + ln, err := net.Listen(server.Network, server.Addr) + logs.Info("graceful http server Running on http://%s", server.Addr) + if err != nil { + logs.Critical("Listen for HTTP[graceful mode]: ", err) + endRunning <- true + return + } + for _, callback := range app.LifeCycleCallbacks { + callback.AfterStart(app) + } + if err := server.ServeWithListener(ln); err != nil { + logs.Critical("ServeWithListener: ", err, fmt.Sprintf("%d", os.Getpid())) + time.Sleep(100 * time.Microsecond) + } + endRunning <- true + }() + } + <-endRunning + return + } + + // run normal mode + if app.Cfg.Listen.EnableHTTPS || app.Cfg.Listen.EnableMutualHTTPS { + go func() { + time.Sleep(1000 * time.Microsecond) + if app.Cfg.Listen.HTTPSPort != 0 { + app.Server.Addr = fmt.Sprintf("%s:%d", app.Cfg.Listen.HTTPSAddr, app.Cfg.Listen.HTTPSPort) + } else if app.Cfg.Listen.EnableHTTP { + logs.Info("Start https server error, conflict with http. Please reset https port") + return + } + logs.Info("https server Running on https://%s", app.Server.Addr) + if app.Cfg.Listen.AutoTLS { + m := autocert.Manager{ + Prompt: autocert.AcceptTOS, + HostPolicy: autocert.HostWhitelist(app.Cfg.Listen.Domains...), + Cache: autocert.DirCache(app.Cfg.Listen.TLSCacheDir), + } + app.Server.TLSConfig = &tls.Config{GetCertificate: m.GetCertificate} + app.Cfg.Listen.HTTPSCertFile, app.Cfg.Listen.HTTPSKeyFile = "", "" + } else if app.Cfg.Listen.EnableMutualHTTPS { + pool := x509.NewCertPool() + data, err := os.ReadFile(app.Cfg.Listen.TrustCaFile) + if err != nil { + logs.Info("MutualHTTPS should provide TrustCaFile") + return + } + pool.AppendCertsFromPEM(data) + app.Server.TLSConfig = &tls.Config{ + ClientCAs: pool, + ClientAuth: tls.ClientAuthType(app.Cfg.Listen.ClientAuth), + } + } + if err := app.Server.ListenAndServeTLS(app.Cfg.Listen.HTTPSCertFile, app.Cfg.Listen.HTTPSKeyFile); err != nil { + logs.Critical("ListenAndServeTLS: ", err) + time.Sleep(100 * time.Microsecond) + endRunning <- true + } + }() + } + if app.Cfg.Listen.EnableHTTP { + go func() { + app.Server.Addr = addr + logs.Info("http server Running on http://%s", app.Server.Addr) + if app.Cfg.Listen.ListenTCP4 { + ln, err := net.Listen("tcp4", app.Server.Addr) + if err != nil { + logs.Critical("Listen for HTTP[normal mode]: ", err) + time.Sleep(100 * time.Microsecond) + endRunning <- true + return + } + if err = app.Server.Serve(ln); err != nil { + logs.Critical("Serve: ", err) + time.Sleep(100 * time.Microsecond) + endRunning <- true + return + } + } else { + if err := app.Server.ListenAndServe(); err != nil { + logs.Critical("ListenAndServe: ", err) + time.Sleep(100 * time.Microsecond) + endRunning <- true + } + } + }() + } + <-endRunning +} + +// Router see HttpServer.Router +func Router(rootpath string, c ControllerInterface, mappingMethods ...string) *HttpServer { + return RouterWithOpts(rootpath, c, WithRouterMethods(c, mappingMethods...)) +} + +func RouterWithOpts(rootpath string, c ControllerInterface, opts ...ControllerOption) *HttpServer { + return BeeApp.RouterWithOpts(rootpath, c, opts...) +} + +// Router adds a patterned controller handler to BeeApp. +// it's an alias method of HttpServer.Router. +// usage: +// +// simple router +// beego.Router("/admin", &admin.UserController{}) +// beego.Router("/admin/index", &admin.ArticleController{}) +// +// regex router +// +// beego.Router("/api/:id([0-9]+)", &controllers.RController{}) +// +// custom rules +// beego.Router("/api/list",&RestController{},"*:ListFood") +// beego.Router("/api/create",&RestController{},"post:CreateFood") +// beego.Router("/api/update",&RestController{},"put:UpdateFood") +// beego.Router("/api/delete",&RestController{},"delete:DeleteFood") +func (app *HttpServer) Router(rootPath string, c ControllerInterface, mappingMethods ...string) *HttpServer { + return app.RouterWithOpts(rootPath, c, WithRouterMethods(c, mappingMethods...)) +} + +func (app *HttpServer) RouterWithOpts(rootPath string, c ControllerInterface, opts ...ControllerOption) *HttpServer { + app.Handlers.Add(rootPath, c, opts...) + return app +} + +// UnregisterFixedRoute see HttpServer.UnregisterFixedRoute +func UnregisterFixedRoute(fixedRoute string, method string) *HttpServer { + return BeeApp.UnregisterFixedRoute(fixedRoute, method) +} + +// UnregisterFixedRoute unregisters the route with the specified fixedRoute. It is particularly useful +// in web applications that inherit most routes from a base webapp via the underscore +// import, and aim to overwrite only certain paths. +// The method parameter can be empty or "*" for all HTTP methods, or a particular +// method type (e.g. "GET" or "POST") for selective removal. +// +// Usage (replace "GET" with "*" for all methods): +// +// beego.UnregisterFixedRoute("/yourpreviouspath", "GET") +// beego.Router("/yourpreviouspath", yourControllerAddress, "get:GetNewPage") +func (app *HttpServer) UnregisterFixedRoute(fixedRoute string, method string) *HttpServer { + subPaths := splitPath(fixedRoute) + if method == "" || method == "*" { + for m := range HTTPMETHOD { + if _, ok := app.Handlers.routers[m]; !ok { + continue + } + if app.Handlers.routers[m].prefix == strings.Trim(fixedRoute, "/ ") { + findAndRemoveSingleTree(app.Handlers.routers[m]) + continue + } + findAndRemoveTree(subPaths, app.Handlers.routers[m], m) + } + return app + } + // Single HTTP method + um := strings.ToUpper(method) + if _, ok := app.Handlers.routers[um]; ok { + if app.Handlers.routers[um].prefix == strings.Trim(fixedRoute, "/ ") { + findAndRemoveSingleTree(app.Handlers.routers[um]) + return app + } + findAndRemoveTree(subPaths, app.Handlers.routers[um], um) + } + return app +} + +func findAndRemoveTree(paths []string, entryPointTree *Tree, method string) { + for i := range entryPointTree.fixrouters { + if entryPointTree.fixrouters[i].prefix == paths[0] { + if len(paths) == 1 { + if len(entryPointTree.fixrouters[i].fixrouters) > 0 { + // If the route had children subtrees, remove just the functional leaf, + // to allow children to function as before + if len(entryPointTree.fixrouters[i].leaves) > 0 { + entryPointTree.fixrouters[i].leaves[0] = nil + entryPointTree.fixrouters[i].leaves = entryPointTree.fixrouters[i].leaves[1:] + } + } else { + // Remove the *Tree from the fixrouters slice + entryPointTree.fixrouters[i] = nil + + if i == len(entryPointTree.fixrouters)-1 { + entryPointTree.fixrouters = entryPointTree.fixrouters[:i] + } else { + entryPointTree.fixrouters = append(entryPointTree.fixrouters[:i], entryPointTree.fixrouters[i+1:len(entryPointTree.fixrouters)]...) + } + } + return + } + findAndRemoveTree(paths[1:], entryPointTree.fixrouters[i], method) + } + } +} + +func findAndRemoveSingleTree(entryPointTree *Tree) { + if entryPointTree == nil { + return + } + if len(entryPointTree.fixrouters) > 0 { + // If the route had children subtrees, remove just the functional leaf, + // to allow children to function as before + if len(entryPointTree.leaves) > 0 { + entryPointTree.leaves[0] = nil + entryPointTree.leaves = entryPointTree.leaves[1:] + } + } +} + +// Include see HttpServer.Include +func Include(cList ...ControllerInterface) *HttpServer { + return BeeApp.Include(cList...) +} + +// Include will generate router file in the router/xxx.go from the controller's comments +// usage: +// beego.Include(&BankAccount{}, &OrderController{},&RefundController{},&ReceiptController{}) +// +// type BankAccount struct{ +// beego.Controller +// } +// +// register the function +// +// func (b *BankAccount)Mapping(){ +// b.Mapping("ShowAccount" , b.ShowAccount) +// b.Mapping("ModifyAccount", b.ModifyAccount) +// } +// +// //@router /account/:id [get] +// +// func (b *BankAccount) ShowAccount(){ +// //logic +// } +// +// //@router /account/:id [post] +// +// func (b *BankAccount) ModifyAccount(){ +// //logic +// } +// +// the comments @router url methodlist +// url support all the function Router's pattern +// methodlist [get post head put delete options *] +func (app *HttpServer) Include(cList ...ControllerInterface) *HttpServer { + app.Handlers.Include(cList...) + return app +} + +// RESTRouter see HttpServer.RESTRouter +func RESTRouter(rootpath string, c ControllerInterface) *HttpServer { + return BeeApp.RESTRouter(rootpath, c) +} + +// RESTRouter adds a restful controller handler to BeeApp. +// its' controller implements beego.ControllerInterface and +// defines a param "pattern/:objectId" to visit each resource. +func (app *HttpServer) RESTRouter(rootpath string, c ControllerInterface) *HttpServer { + app.Router(rootpath, c) + app.Router(path.Join(rootpath, ":objectId"), c) + return app +} + +// AutoRouter see HttpServer.AutoRouter +func AutoRouter(c ControllerInterface) *HttpServer { + return BeeApp.AutoRouter(c) +} + +// AutoRouter adds defined controller handler to BeeApp. +// it's same to HttpServer.AutoRouter. +// if beego.AddAuto(&MainController{}) and MainController has methods List and Page, +// visit the url /main/list to exec List function or /main/page to exec Page function. +func (app *HttpServer) AutoRouter(c ControllerInterface) *HttpServer { + app.Handlers.AddAuto(c) + return app +} + +// AutoPrefix see HttpServer.AutoPrefix +func AutoPrefix(prefix string, c ControllerInterface) *HttpServer { + return BeeApp.AutoPrefix(prefix, c) +} + +// AutoPrefix adds controller handler to BeeApp with prefix. +// it's same to HttpServer.AutoRouterWithPrefix. +// if beego.AutoPrefix("/admin",&MainController{}) and MainController has methods List and Page, +// visit the url /admin/main/list to exec List function or /admin/main/page to exec Page function. +func (app *HttpServer) AutoPrefix(prefix string, c ControllerInterface) *HttpServer { + app.Handlers.AddAutoPrefix(prefix, c) + return app +} + +// CtrlGet see HttpServer.CtrlGet +func CtrlGet(rootpath string, f interface{}) { + BeeApp.CtrlGet(rootpath, f) +} + +// CtrlGet used to register router for CtrlGet method +// usage: +// +// type MyController struct { +// web.Controller +// } +// func (m MyController) Ping() { +// m.Ctx.Output.Body([]byte("hello world")) +// } +// +// CtrlGet("/api/:id", MyController.Ping) +func (app *HttpServer) CtrlGet(rootpath string, f interface{}) *HttpServer { + app.Handlers.CtrlGet(rootpath, f) + return app +} + +// CtrlPost see HttpServer.CtrlGet +func CtrlPost(rootpath string, f interface{}) { + BeeApp.CtrlPost(rootpath, f) +} + +// CtrlPost used to register router for CtrlPost method +// usage: +// +// type MyController struct { +// web.Controller +// } +// func (m MyController) Ping() { +// m.Ctx.Output.Body([]byte("hello world")) +// } +// +// CtrlPost("/api/:id", MyController.Ping) +func (app *HttpServer) CtrlPost(rootpath string, f interface{}) *HttpServer { + app.Handlers.CtrlPost(rootpath, f) + return app +} + +// CtrlHead see HttpServer.CtrlHead +func CtrlHead(rootpath string, f interface{}) { + BeeApp.CtrlHead(rootpath, f) +} + +// CtrlHead used to register router for CtrlHead method +// usage: +// +// type MyController struct { +// web.Controller +// } +// func (m MyController) Ping() { +// m.Ctx.Output.Body([]byte("hello world")) +// } +// +// CtrlHead("/api/:id", MyController.Ping) +func (app *HttpServer) CtrlHead(rootpath string, f interface{}) *HttpServer { + app.Handlers.CtrlHead(rootpath, f) + return app +} + +// CtrlPut see HttpServer.CtrlPut +func CtrlPut(rootpath string, f interface{}) { + BeeApp.CtrlPut(rootpath, f) +} + +// CtrlPut used to register router for CtrlPut method +// usage: +// +// type MyController struct { +// web.Controller +// } +// func (m MyController) Ping() { +// m.Ctx.Output.Body([]byte("hello world")) +// } +// +// CtrlPut("/api/:id", MyController.Ping) +func (app *HttpServer) CtrlPut(rootpath string, f interface{}) *HttpServer { + app.Handlers.CtrlPut(rootpath, f) + return app +} + +// CtrlPatch see HttpServer.CtrlPatch +func CtrlPatch(rootpath string, f interface{}) { + BeeApp.CtrlPatch(rootpath, f) +} + +// CtrlPatch used to register router for CtrlPatch method +// usage: +// +// type MyController struct { +// web.Controller +// } +// func (m MyController) Ping() { +// m.Ctx.Output.Body([]byte("hello world")) +// } +// +// CtrlPatch("/api/:id", MyController.Ping) +func (app *HttpServer) CtrlPatch(rootpath string, f interface{}) *HttpServer { + app.Handlers.CtrlPatch(rootpath, f) + return app +} + +// CtrlDelete see HttpServer.CtrlDelete +func CtrlDelete(rootpath string, f interface{}) { + BeeApp.CtrlDelete(rootpath, f) +} + +// CtrlDelete used to register router for CtrlDelete method +// usage: +// +// type MyController struct { +// web.Controller +// } +// func (m MyController) Ping() { +// m.Ctx.Output.Body([]byte("hello world")) +// } +// +// CtrlDelete("/api/:id", MyController.Ping) +func (app *HttpServer) CtrlDelete(rootpath string, f interface{}) *HttpServer { + app.Handlers.CtrlDelete(rootpath, f) + return app +} + +// CtrlOptions see HttpServer.CtrlOptions +func CtrlOptions(rootpath string, f interface{}) { + BeeApp.CtrlOptions(rootpath, f) +} + +// CtrlOptions used to register router for CtrlOptions method +// usage: +// +// type MyController struct { +// web.Controller +// } +// func (m MyController) Ping() { +// m.Ctx.Output.Body([]byte("hello world")) +// } +// +// CtrlOptions("/api/:id", MyController.Ping) +func (app *HttpServer) CtrlOptions(rootpath string, f interface{}) *HttpServer { + app.Handlers.CtrlOptions(rootpath, f) + return app +} + +// CtrlAny see HttpServer.CtrlAny +func CtrlAny(rootpath string, f interface{}) { + BeeApp.CtrlAny(rootpath, f) +} + +// CtrlAny used to register router for CtrlAny method +// usage: +// +// type MyController struct { +// web.Controller +// } +// func (m MyController) Ping() { +// m.Ctx.Output.Body([]byte("hello world")) +// } +// +// CtrlAny("/api/:id", MyController.Ping) +func (app *HttpServer) CtrlAny(rootpath string, f interface{}) *HttpServer { + app.Handlers.CtrlAny(rootpath, f) + return app +} + +// Get see HttpServer.Get +func Get(rootpath string, f HandleFunc) *HttpServer { + return BeeApp.Get(rootpath, f) +} + +// Get used to register router for Get method +// usage: +// +// beego.Get("/", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func (app *HttpServer) Get(rootpath string, f HandleFunc) *HttpServer { + app.Handlers.Get(rootpath, f) + return app +} + +// Post see HttpServer.Post +func Post(rootpath string, f HandleFunc) *HttpServer { + return BeeApp.Post(rootpath, f) +} + +// Post used to register router for Post method +// usage: +// +// beego.Post("/api", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func (app *HttpServer) Post(rootpath string, f HandleFunc) *HttpServer { + app.Handlers.Post(rootpath, f) + return app +} + +// Delete see HttpServer.Delete +func Delete(rootpath string, f HandleFunc) *HttpServer { + return BeeApp.Delete(rootpath, f) +} + +// Delete used to register router for Delete method +// usage: +// +// beego.Delete("/api", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func (app *HttpServer) Delete(rootpath string, f HandleFunc) *HttpServer { + app.Handlers.Delete(rootpath, f) + return app +} + +// Put see HttpServer.Put +func Put(rootpath string, f HandleFunc) *HttpServer { + return BeeApp.Put(rootpath, f) +} + +// Put used to register router for Put method +// usage: +// +// beego.Put("/api", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func (app *HttpServer) Put(rootpath string, f HandleFunc) *HttpServer { + app.Handlers.Put(rootpath, f) + return app +} + +// Head see HttpServer.Head +func Head(rootpath string, f HandleFunc) *HttpServer { + return BeeApp.Head(rootpath, f) +} + +// Head used to register router for Head method +// usage: +// +// beego.Head("/api", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func (app *HttpServer) Head(rootpath string, f HandleFunc) *HttpServer { + app.Handlers.Head(rootpath, f) + return app +} + +// Options see HttpServer.Options +func Options(rootpath string, f HandleFunc) *HttpServer { + BeeApp.Handlers.Options(rootpath, f) + return BeeApp +} + +// Options used to register router for Options method +// usage: +// +// beego.Options("/api", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func (app *HttpServer) Options(rootpath string, f HandleFunc) *HttpServer { + app.Handlers.Options(rootpath, f) + return app +} + +// Patch see HttpServer.Patch +func Patch(rootpath string, f HandleFunc) *HttpServer { + return BeeApp.Patch(rootpath, f) +} + +// Patch used to register router for Patch method +// usage: +// +// beego.Patch("/api", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func (app *HttpServer) Patch(rootpath string, f HandleFunc) *HttpServer { + app.Handlers.Patch(rootpath, f) + return app +} + +// Any see HttpServer.Any +func Any(rootpath string, f HandleFunc) *HttpServer { + return BeeApp.Any(rootpath, f) +} + +// Any used to register router for all methods +// usage: +// +// beego.Any("/api", func(ctx *context.Context){ +// ctx.Output.Body("hello world") +// }) +func (app *HttpServer) Any(rootpath string, f HandleFunc) *HttpServer { + app.Handlers.Any(rootpath, f) + return app +} + +// Handler see HttpServer.Handler +func Handler(rootpath string, h http.Handler, options ...interface{}) *HttpServer { + return BeeApp.Handler(rootpath, h, options...) +} + +// Handler used to register a Handler router +// usage: +// +// beego.Handler("/api", http.HandlerFunc(func (w http.ResponseWriter, r *http.Request) { +// fmt.Fprintf(w, "Hello, %q", html.EscapeString(r.URL.Path)) +// })) +func (app *HttpServer) Handler(rootpath string, h http.Handler, options ...interface{}) *HttpServer { + app.Handlers.Handler(rootpath, h, options...) + return app +} + +// InsertFilter see HttpServer.InsertFilter +func InsertFilter(pattern string, pos int, filter FilterFunc, opts ...FilterOpt) *HttpServer { + return BeeApp.InsertFilter(pattern, pos, filter, opts...) +} + +// InsertFilter adds a FilterFunc with pattern condition and action constant. +// The pos means action constant including +// beego.BeforeStatic, beego.BeforeRouter, beego.BeforeExec, beego.AfterExec and beego.FinishRouter. +// The bool params is for setting the returnOnOutput value (false allows multiple filters to execute) +func (app *HttpServer) InsertFilter(pattern string, pos int, filter FilterFunc, opts ...FilterOpt) *HttpServer { + app.Handlers.InsertFilter(pattern, pos, filter, opts...) + return app +} + +// InsertFilterChain see HttpServer.InsertFilterChain +func InsertFilterChain(pattern string, filterChain FilterChain, opts ...FilterOpt) *HttpServer { + return BeeApp.InsertFilterChain(pattern, filterChain, opts...) +} + +// InsertFilterChain adds a FilterFunc built by filterChain. +// This filter will be executed before all filters. +// the filter's behavior like stack's behavior +// and the last filter is serving the http request +func (app *HttpServer) InsertFilterChain(pattern string, filterChain FilterChain, opts ...FilterOpt) *HttpServer { + app.Handlers.InsertFilterChain(pattern, filterChain, opts...) + return app +} + +func (app *HttpServer) initAddr(addr string) { + strs := strings.Split(addr, ":") + if len(strs) > 0 && strs[0] != "" { + app.Cfg.Listen.HTTPAddr = strs[0] + app.Cfg.Listen.Domains = []string{strs[0]} + } + if len(strs) > 1 && strs[1] != "" { + app.Cfg.Listen.HTTPPort, _ = strconv.Atoi(strs[1]) + } +} + +func (app *HttpServer) LogAccess(ctx *beecontext.Context, startTime *time.Time, statusCode int) { + // Skip logging if AccessLogs config is false + if !app.Cfg.Log.AccessLogs { + return + } + // Skip logging static requests unless EnableStaticLogs config is true + if !app.Cfg.Log.EnableStaticLogs && DefaultAccessLogFilter.Filter(ctx) { + return + } + var ( + requestTime time.Time + elapsedTime time.Duration + r = ctx.Request + ) + if startTime != nil { + requestTime = *startTime + elapsedTime = time.Since(*startTime) + } + record := &logs.AccessLogRecord{ + RemoteAddr: ctx.Input.IP(), + RequestTime: requestTime, + RequestMethod: r.Method, + Request: fmt.Sprintf("%s %s %s", r.Method, r.RequestURI, r.Proto), + ServerProtocol: r.Proto, + Host: r.Host, + Status: statusCode, + ElapsedTime: elapsedTime, + HTTPReferrer: r.Header.Get("Referer"), + HTTPUserAgent: r.Header.Get("User-Agent"), + RemoteUser: r.Header.Get("Remote-User"), + BodyBytesSent: r.ContentLength, + } + logs.AccessLog(record, app.Cfg.Log.AccessLogsFormat) +} + +// PrintTree prints all registered routers. +func (app *HttpServer) PrintTree() M { + var ( + content = M{} + methods = []string{} + methodsData = make(M) + ) + for method, t := range app.Handlers.routers { + + resultList := new([][]string) + + printTree(resultList, t) + + methods = append(methods, template.HTMLEscapeString(method)) + methodsData[template.HTMLEscapeString(method)] = resultList + } + + content["Data"] = methodsData + content["Methods"] = methods + return content +} + +func printTree(resultList *[][]string, t *Tree) { + for _, tr := range t.fixrouters { + printTree(resultList, tr) + } + if t.wildcard != nil { + printTree(resultList, t.wildcard) + } + for _, l := range t.leaves { + if v, ok := l.runObject.(*ControllerInfo); ok { + if v.routerType == routerTypeBeego { + result := []string{ + template.HTMLEscapeString(v.pattern), + template.HTMLEscapeString(fmt.Sprintf("%s", v.methods)), + template.HTMLEscapeString(v.controllerType.String()), + } + *resultList = append(*resultList, result) + } else if v.routerType == routerTypeRESTFul { + result := []string{ + template.HTMLEscapeString(v.pattern), + template.HTMLEscapeString(fmt.Sprintf("%s", v.methods)), + "", + } + *resultList = append(*resultList, result) + } else if v.routerType == routerTypeHandler { + result := []string{ + template.HTMLEscapeString(v.pattern), + "", + "", + } + *resultList = append(*resultList, result) + } + } + } +} + +func (app *HttpServer) reportFilter() M { + filterTypeData := make(M) + // filterTypes := []string{} + if app.Handlers.enableFilter { + // var filterType string + for k, fr := range map[int]string{ + BeforeStatic: "Before Static", + BeforeRouter: "Before Router", + BeforeExec: "Before Exec", + AfterExec: "After Exec", + FinishRouter: "Finish Router", + } { + if bf := app.Handlers.filters[k]; len(bf) > 0 { + resultList := new([][]string) + for _, f := range bf { + result := []string{ + // void xss + template.HTMLEscapeString(f.pattern), + template.HTMLEscapeString(utils.GetFuncName(f.filterFunc)), + } + *resultList = append(*resultList, result) + } + filterTypeData[fr] = resultList + } + } + } + + return filterTypeData +} diff --git a/server/web/server_test.go b/server/web/server_test.go new file mode 100644 index 0000000000..1ddf217e14 --- /dev/null +++ b/server/web/server_test.go @@ -0,0 +1,109 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package web + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestNewHttpServerWithCfg(t *testing.T) { + BConfig.AppName = "Before" + svr := NewHttpServerWithCfg(BConfig) + svr.Cfg.AppName = "hello" + assert.Equal(t, "hello", BConfig.AppName) +} + +func TestServerCtrlGet(t *testing.T) { + r, _ := http.NewRequest(http.MethodGet, "/user", nil) + w := httptest.NewRecorder() + + CtrlGet("/user", ExampleController.Ping) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != exampleBody { + t.Errorf("TestServerCtrlGet can't run") + } +} + +func TestServerCtrlPost(t *testing.T) { + r, _ := http.NewRequest(http.MethodPost, "/user", nil) + w := httptest.NewRecorder() + + CtrlPost("/user", ExampleController.Ping) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != exampleBody { + t.Errorf("TestServerCtrlPost can't run") + } +} + +func TestServerCtrlHead(t *testing.T) { + r, _ := http.NewRequest(http.MethodHead, "/user", nil) + w := httptest.NewRecorder() + + CtrlHead("/user", ExampleController.Ping) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != exampleBody { + t.Errorf("TestServerCtrlHead can't run") + } +} + +func TestServerCtrlPut(t *testing.T) { + r, _ := http.NewRequest(http.MethodPut, "/user", nil) + w := httptest.NewRecorder() + + CtrlPut("/user", ExampleController.Ping) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != exampleBody { + t.Errorf("TestServerCtrlPut can't run") + } +} + +func TestServerCtrlPatch(t *testing.T) { + r, _ := http.NewRequest(http.MethodPatch, "/user", nil) + w := httptest.NewRecorder() + + CtrlPatch("/user", ExampleController.Ping) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != exampleBody { + t.Errorf("TestServerCtrlPatch can't run") + } +} + +func TestServerCtrlDelete(t *testing.T) { + r, _ := http.NewRequest(http.MethodDelete, "/user", nil) + w := httptest.NewRecorder() + + CtrlDelete("/user", ExampleController.Ping) + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != exampleBody { + t.Errorf("TestServerCtrlDelete can't run") + } +} + +func TestServerCtrlAny(t *testing.T) { + CtrlAny("/user", ExampleController.Ping) + + for method := range HTTPMETHOD { + r, _ := http.NewRequest(method, "/user", nil) + w := httptest.NewRecorder() + BeeApp.Handlers.ServeHTTP(w, r) + if w.Body.String() != exampleBody { + t.Errorf("TestServerCtrlAny can't run") + } + } +} diff --git a/server/web/session/README.md b/server/web/session/README.md new file mode 100644 index 0000000000..ea51360ee4 --- /dev/null +++ b/server/web/session/README.md @@ -0,0 +1,113 @@ +session +============== + +session is a Go session manager. It can use many session providers. Just like the `database/sql` +and `database/sql/driver`. + +## How to install? + + go get github.com/beego/beego/v2/server/web/session + +## What providers are supported? + +As of now this session manager support memory, file, Redis and MySQL. + +## How to use it? + +First you must import it + + import ( + "github.com/beego/beego/v2/server/web/session" + ) + +Then in you web app init the global session manager + + var globalSessions *session.Manager + +* Use **memory** as provider: + + func init() { + globalSessions, _ = session.NewManager("memory", `{"cookieName":"gosessionid","gclifetime":3600}`) + go globalSessions.GC() + } + +* Use **file** as provider, the last param is the path where you want file to be stored: + + func init() { + globalSessions, _ = session.NewManager("file",`{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"./tmp"}`) + go globalSessions.GC() + } + +* Use **Redis** as provider, the last param is the Redis conn address,poolsize,password: + + func init() { + globalSessions, _ = session.NewManager("redis", `{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"127.0.0.1:6379,100,astaxie"}`) + go globalSessions.GC() + } + +* Use **MySQL** as provider, the last param is the DSN, learn more + from [mysql](https://github.com/go-sql-driver/mysql#dsn-data-source-name): + + func init() { + globalSessions, _ = session.NewManager( + "mysql", `{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"username:password@protocol(address)/dbname?param=value"}`) + go globalSessions.GC() + } + +* Use **Cookie** as provider: + + func init() { + globalSessions, _ = session.NewManager( + "cookie", `{"cookieName":"gosessionid","enableSetCookie":false,"gclifetime":3600,"ProviderConfig":"{\"cookieName\":\"gosessionid\",\"securityKey\":\"beegocookiehashkey\"}"}`) + go globalSessions.GC() + } + +Finally in the handlerfunc you can use it like this + + func login(w http.ResponseWriter, r *http.Request) { + sess := globalSessions.SessionStart(w, r) + defer sess.SessionRelease(w) + username := sess.Get("username") + fmt.Println(username) + if r.Method == "GET" { + t, _ := template.ParseFiles("login.gtpl") + t.Execute(w, nil) + } else { + fmt.Println("username:", r.Form["username"]) + sess.Set("username", r.Form["username"]) + fmt.Println("password:", r.Form["password"]) + } + } + +## How to write own provider? + +When you develop a web app, maybe you want to write own provider because you must meet the requirements. + +Writing a provider is easy. You only need to define two struct types +(Store and Provider), which satisfy the interface definition. Maybe you will find the **memory** provider is a good +example. + + // Store contains all data for one session process with specific id. + type Store interface { + Set(ctx context.Context, key, value interface{}) error // Set set session value + Get(ctx context.Context, key interface{}) interface{} // Get get session value + Delete(ctx context.Context, key interface{}) error // Delete delete session value + SessionID(ctx context.Context) string // SessionID return current sessionID + SessionReleaseIfPresent(ctx context.Context, w http.ResponseWriter) // SessionReleaseIfPresent release the resource & save data to provider & return the data when the session is present, not all implementation support this feature, you need to check if the specific implementation if support this feature. + SessionRelease(ctx context.Context, w http.ResponseWriter) // SessionRelease release the resource & save data to provider & return the data + Flush(ctx context.Context) error // Flush delete all data + } + + type Provider interface { + SessionInit(gclifetime int64, config string) error + SessionRead(sid string) (SessionStore, error) + SessionExist(sid string) (bool, error) + SessionRegenerate(oldsid, sid string) (SessionStore, error) + SessionDestroy(sid string) error + SessionAll() int //get all active session + SessionGC() + } + +## LICENSE + +BSD License http://creativecommons.org/licenses/BSD/ diff --git a/session/couchbase/sess_couchbase.go b/server/web/session/couchbase/sess_couchbase.go similarity index 62% rename from session/couchbase/sess_couchbase.go rename to server/web/session/couchbase/sess_couchbase.go index 707d042c5c..514bdf6d63 100644 --- a/session/couchbase/sess_couchbase.go +++ b/server/web/session/couchbase/sess_couchbase.go @@ -20,26 +20,28 @@ // // Usage: // import( -// _ "github.com/astaxie/beego/session/couchbase" -// "github.com/astaxie/beego/session" +// +// _ "github.com/beego/beego/v2/server/web/session/couchbase" +// "github.com/beego/beego/v2/server/web/session" +// // ) // // func init() { // globalSessions, _ = session.NewManager("couchbase", ``{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"http://host:port/, Pool, Bucket"}``) // go globalSessions.GC() // } -// -// more docs: http://beego.me/docs/module/session.md package couchbase import ( + "context" + "encoding/json" "net/http" "strings" "sync" couchbase "github.com/couchbase/go-couchbase" - "github.com/astaxie/beego/session" + "github.com/beego/beego/v2/server/web/session" ) var couchbpder = &Provider{} @@ -56,22 +58,22 @@ type SessionStore struct { // Provider couchabse provided type Provider struct { maxlifetime int64 - savePath string - pool string - bucket string + SavePath string `json:"save_path"` + Pool string `json:"pool"` + Bucket string `json:"bucket"` b *couchbase.Bucket } -// Set value to couchabse session -func (cs *SessionStore) Set(key, value interface{}) error { +// Set value to couchbase session +func (cs *SessionStore) Set(ctx context.Context, key, value interface{}) error { cs.lock.Lock() defer cs.lock.Unlock() cs.values[key] = value return nil } -// Get value from couchabse session -func (cs *SessionStore) Get(key interface{}) interface{} { +// Get value from couchbase session +func (cs *SessionStore) Get(ctx context.Context, key interface{}) interface{} { cs.lock.RLock() defer cs.lock.RUnlock() if v, ok := cs.values[key]; ok { @@ -81,7 +83,7 @@ func (cs *SessionStore) Get(key interface{}) interface{} { } // Delete value in couchbase session by given key -func (cs *SessionStore) Delete(key interface{}) error { +func (cs *SessionStore) Delete(ctx context.Context, key interface{}) error { cs.lock.Lock() defer cs.lock.Unlock() delete(cs.values, key) @@ -89,7 +91,7 @@ func (cs *SessionStore) Delete(key interface{}) error { } // Flush Clean all values in couchbase session -func (cs *SessionStore) Flush() error { +func (cs *SessionStore) Flush(context.Context) error { cs.lock.Lock() defer cs.lock.Unlock() cs.values = make(map[interface{}]interface{}) @@ -97,15 +99,17 @@ func (cs *SessionStore) Flush() error { } // SessionID Get couchbase session store id -func (cs *SessionStore) SessionID() string { +func (cs *SessionStore) SessionID(context.Context) string { return cs.sid } // SessionRelease Write couchbase session with Gob string -func (cs *SessionStore) SessionRelease(w http.ResponseWriter) { +func (cs *SessionStore) SessionRelease(_ context.Context, _ http.ResponseWriter) { defer cs.b.Close() - - bo, err := session.EncodeGob(cs.values) + cs.lock.RLock() + values := cs.values + cs.lock.RUnlock() + bo, err := session.EncodeGob(values) if err != nil { return } @@ -113,18 +117,24 @@ func (cs *SessionStore) SessionRelease(w http.ResponseWriter) { cs.b.Set(cs.sid, int(cs.maxlifetime), bo) } +// SessionReleaseIfPresent is not supported now. +// If we want to use couchbase, we may refactor the code to use couchbase collection. +func (cs *SessionStore) SessionReleaseIfPresent(c context.Context, w http.ResponseWriter) { + cs.SessionRelease(c, w) +} + func (cp *Provider) getBucket() *couchbase.Bucket { - c, err := couchbase.Connect(cp.savePath) + c, err := couchbase.Connect(cp.SavePath) if err != nil { return nil } - pool, err := c.GetPool(cp.pool) + pool, err := c.GetPool(cp.Pool) if err != nil { return nil } - bucket, err := pool.GetBucket(cp.bucket) + bucket, err := pool.GetBucket(cp.Bucket) if err != nil { return nil } @@ -134,25 +144,38 @@ func (cp *Provider) getBucket() *couchbase.Bucket { // SessionInit init couchbase session // savepath like couchbase server REST/JSON URL -// e.g. http://host:port/, Pool, Bucket -func (cp *Provider) SessionInit(maxlifetime int64, savePath string) error { +// For v1.x e.g. http://host:port/, Pool, Bucket +// For v2.x, you should pass json string. +// e.g. { "save_path": "http://host:port/", "pool": "mypool", "bucket": "mybucket"} +func (cp *Provider) SessionInit(ctx context.Context, maxlifetime int64, cfg string) error { cp.maxlifetime = maxlifetime + cfg = strings.TrimSpace(cfg) + // we think this is v2.0, using json to init the session + if strings.HasPrefix(cfg, "{") { + return json.Unmarshal([]byte(cfg), cp) + } else { + return cp.initOldStyle(cfg) + } +} + +// initOldStyle keep compatible with v1.x +func (cp *Provider) initOldStyle(savePath string) error { configs := strings.Split(savePath, ",") if len(configs) > 0 { - cp.savePath = configs[0] + cp.SavePath = configs[0] } if len(configs) > 1 { - cp.pool = configs[1] + cp.Pool = configs[1] } if len(configs) > 2 { - cp.bucket = configs[2] + cp.Bucket = configs[2] } return nil } // SessionRead read couchbase session by sid -func (cp *Provider) SessionRead(sid string) (session.Store, error) { +func (cp *Provider) SessionRead(ctx context.Context, sid string) (session.Store, error) { cp.b = cp.getBucket() var ( @@ -178,21 +201,21 @@ func (cp *Provider) SessionRead(sid string) (session.Store, error) { } // SessionExist Check couchbase session exist. -// it checkes sid exist or not. -func (cp *Provider) SessionExist(sid string) bool { +// it checks sid exist or not. +func (cp *Provider) SessionExist(ctx context.Context, sid string) (bool, error) { cp.b = cp.getBucket() defer cp.b.Close() var doc []byte if err := cp.b.Get(sid, &doc); err != nil || doc == nil { - return false + return false, err } - return true + return true, nil } // SessionRegenerate remove oldsid and use sid to generate new session -func (cp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { +func (cp *Provider) SessionRegenerate(ctx context.Context, oldsid, sid string) (session.Store, error) { cp.b = cp.getBucket() var doc []byte @@ -224,8 +247,8 @@ func (cp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) return cs, nil } -// SessionDestroy Remove bucket in this couchbase -func (cp *Provider) SessionDestroy(sid string) error { +// SessionDestroy Remove Bucket in this couchbase +func (cp *Provider) SessionDestroy(ctx context.Context, sid string) error { cp.b = cp.getBucket() defer cp.b.Close() @@ -234,11 +257,11 @@ func (cp *Provider) SessionDestroy(sid string) error { } // SessionGC Recycle -func (cp *Provider) SessionGC() { +func (cp *Provider) SessionGC(context.Context) { } // SessionAll return all active session -func (cp *Provider) SessionAll() int { +func (cp *Provider) SessionAll(context.Context) int { return 0 } diff --git a/server/web/session/couchbase/sess_couchbase_test.go b/server/web/session/couchbase/sess_couchbase_test.go new file mode 100644 index 0000000000..5959f9c3b7 --- /dev/null +++ b/server/web/session/couchbase/sess_couchbase_test.go @@ -0,0 +1,43 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package couchbase + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestProvider_SessionInit(t *testing.T) { + // using old style + savePath := `http://host:port/,Pool,Bucket` + cp := &Provider{} + cp.SessionInit(context.Background(), 12, savePath) + assert.Equal(t, "http://host:port/", cp.SavePath) + assert.Equal(t, "Pool", cp.Pool) + assert.Equal(t, "Bucket", cp.Bucket) + assert.Equal(t, int64(12), cp.maxlifetime) + + savePath = ` +{ "save_path": "my save path", "pool": "mypool", "bucket": "mybucket"} +` + cp = &Provider{} + cp.SessionInit(context.Background(), 12, savePath) + assert.Equal(t, "my save path", cp.SavePath) + assert.Equal(t, "mypool", cp.Pool) + assert.Equal(t, "mybucket", cp.Bucket) + assert.Equal(t, int64(12), cp.maxlifetime) +} diff --git a/server/web/session/ledis/ledis_session.go b/server/web/session/ledis/ledis_session.go new file mode 100644 index 0000000000..bf9efa6faa --- /dev/null +++ b/server/web/session/ledis/ledis_session.go @@ -0,0 +1,202 @@ +// Package ledis provide session Provider +package ledis + +import ( + "context" + "encoding/json" + "net/http" + "strconv" + "strings" + "sync" + + "github.com/ledisdb/ledisdb/config" + "github.com/ledisdb/ledisdb/ledis" + + "github.com/beego/beego/v2/server/web/session" +) + +var ( + ledispder = &Provider{} + c *ledis.DB +) + +// SessionStore ledis session store +type SessionStore struct { + sid string + lock sync.RWMutex + values map[interface{}]interface{} + maxlifetime int64 +} + +// Set value in ledis session +func (ls *SessionStore) Set(ctx context.Context, key, value interface{}) error { + ls.lock.Lock() + defer ls.lock.Unlock() + ls.values[key] = value + return nil +} + +// Get value in ledis session +func (ls *SessionStore) Get(ctx context.Context, key interface{}) interface{} { + ls.lock.RLock() + defer ls.lock.RUnlock() + if v, ok := ls.values[key]; ok { + return v + } + return nil +} + +// Delete value in ledis session +func (ls *SessionStore) Delete(ctx context.Context, key interface{}) error { + ls.lock.Lock() + defer ls.lock.Unlock() + delete(ls.values, key) + return nil +} + +// Flush clear all values in ledis session +func (ls *SessionStore) Flush(context.Context) error { + ls.lock.Lock() + defer ls.lock.Unlock() + ls.values = make(map[interface{}]interface{}) + return nil +} + +// SessionID get ledis session id +func (ls *SessionStore) SessionID(context.Context) string { + return ls.sid +} + +// SessionRelease save session values to ledis +func (ls *SessionStore) SessionRelease(_ context.Context, _ http.ResponseWriter) { + ls.lock.RLock() + values := ls.values + ls.lock.RUnlock() + b, err := session.EncodeGob(values) + if err != nil { + return + } + c.Set([]byte(ls.sid), b) + c.Expire([]byte(ls.sid), ls.maxlifetime) +} + +// SessionReleaseIfPresent is not supported now, because ledis has no this feature like SETXX or atomic operation. +// https://github.com/ledisdb/ledisdb/issues/251 +// https://github.com/ledisdb/ledisdb/issues/351 +func (ls *SessionStore) SessionReleaseIfPresent(c context.Context, w http.ResponseWriter) { + ls.SessionRelease(c, w) +} + +// Provider ledis session provider +type Provider struct { + maxlifetime int64 + SavePath string `json:"save_path"` + Db int `json:"db"` +} + +// SessionInit init ledis session +// savepath like ledis server saveDataPath,pool size +// v1.x e.g. 127.0.0.1:6379,100 +// v2.x you should pass a json string +// e.g. { "save_path": "my save path", "db": 100} +func (lp *Provider) SessionInit(ctx context.Context, maxlifetime int64, cfgStr string) error { + var err error + lp.maxlifetime = maxlifetime + cfgStr = strings.TrimSpace(cfgStr) + // we think cfgStr is v2.0, using json to init the session + if strings.HasPrefix(cfgStr, "{") { + err = json.Unmarshal([]byte(cfgStr), lp) + } else { + err = lp.initOldStyle(cfgStr) + } + + if err != nil { + return err + } + + cfg := new(config.Config) + cfg.DataDir = lp.SavePath + + var ledisInstance *ledis.Ledis + ledisInstance, err = ledis.Open(cfg) + if err != nil { + return err + } + c, err = ledisInstance.Select(lp.Db) + return err +} + +func (lp *Provider) initOldStyle(cfgStr string) error { + var err error + configs := strings.Split(cfgStr, ",") + if len(configs) == 1 { + lp.SavePath = configs[0] + } else if len(configs) == 2 { + lp.SavePath = configs[0] + lp.Db, err = strconv.Atoi(configs[1]) + } + return err +} + +// SessionRead read ledis session by sid +func (lp *Provider) SessionRead(ctx context.Context, sid string) (session.Store, error) { + var ( + kv map[interface{}]interface{} + err error + ) + + kvs, _ := c.Get([]byte(sid)) + + if len(kvs) == 0 { + kv = make(map[interface{}]interface{}) + } else { + if kv, err = session.DecodeGob(kvs); err != nil { + return nil, err + } + } + + ls := &SessionStore{sid: sid, values: kv, maxlifetime: lp.maxlifetime} + return ls, nil +} + +// SessionExist check ledis session exist by sid +func (lp *Provider) SessionExist(ctx context.Context, sid string) (bool, error) { + count, _ := c.Exists([]byte(sid)) + return count != 0, nil +} + +// SessionRegenerate generate new sid for ledis session +func (lp *Provider) SessionRegenerate(ctx context.Context, oldsid, sid string) (session.Store, error) { + count, _ := c.Exists([]byte(sid)) + if count == 0 { + // oldsid doesn't exist, set the new sid directly + // ignore error here, since if it returns error + // the existed value will be 0 + c.Set([]byte(sid), []byte("")) + c.Expire([]byte(sid), lp.maxlifetime) + } else { + data, _ := c.Get([]byte(oldsid)) + c.Set([]byte(sid), data) + c.Expire([]byte(sid), lp.maxlifetime) + } + return lp.SessionRead(context.Background(), sid) +} + +// SessionDestroy delete ledis session by id +func (lp *Provider) SessionDestroy(ctx context.Context, sid string) error { + c.Del([]byte(sid)) + return nil +} + +// SessionGC Implement method, no used. +func (lp *Provider) SessionGC(context.Context) { +} + +// SessionAll return all active session +func (lp *Provider) SessionAll(context.Context) int { + return 0 +} + +func init() { + session.Register("ledis", ledispder) +} diff --git a/server/web/session/ledis/ledis_session_test.go b/server/web/session/ledis/ledis_session_test.go new file mode 100644 index 0000000000..1cfb3ed1a8 --- /dev/null +++ b/server/web/session/ledis/ledis_session_test.go @@ -0,0 +1,41 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ledis + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestProvider_SessionInit(t *testing.T) { + // using old style + savePath := `http://host:port/,100` + cp := &Provider{} + cp.SessionInit(context.Background(), 12, savePath) + assert.Equal(t, "http://host:port/", cp.SavePath) + assert.Equal(t, 100, cp.Db) + assert.Equal(t, int64(12), cp.maxlifetime) + + savePath = ` +{ "save_path": "my save path", "db": 100} +` + cp = &Provider{} + cp.SessionInit(context.Background(), 12, savePath) + assert.Equal(t, "my save path", cp.SavePath) + assert.Equal(t, 100, cp.Db) + assert.Equal(t, int64(12), cp.maxlifetime) +} diff --git a/session/memcache/sess_memcache.go b/server/web/session/memcache/sess_memcache.go similarity index 69% rename from session/memcache/sess_memcache.go rename to server/web/session/memcache/sess_memcache.go index 85a2d81534..dbc1b8b1f0 100644 --- a/session/memcache/sess_memcache.go +++ b/server/web/session/memcache/sess_memcache.go @@ -20,30 +20,33 @@ // // Usage: // import( -// _ "github.com/astaxie/beego/session/memcache" -// "github.com/astaxie/beego/session" +// +// _ "github.com/beego/beego/v2/server/web/session/memcache" +// "github.com/beego/beego/v2/server/web/session" +// // ) // // func init() { // globalSessions, _ = session.NewManager("memcache", ``{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"127.0.0.1:11211"}``) // go globalSessions.GC() // } -// -// more docs: http://beego.me/docs/module/session.md package memcache import ( + "context" "net/http" "strings" "sync" - "github.com/astaxie/beego/session" - "github.com/bradfitz/gomemcache/memcache" + + "github.com/beego/beego/v2/server/web/session" ) -var mempder = &MemProvider{} -var client *memcache.Client +var ( + mempder = &MemProvider{} + client *memcache.Client +) // SessionStore memcache session store type SessionStore struct { @@ -54,7 +57,7 @@ type SessionStore struct { } // Set value in memcache session -func (rs *SessionStore) Set(key, value interface{}) error { +func (rs *SessionStore) Set(ctx context.Context, key, value interface{}) error { rs.lock.Lock() defer rs.lock.Unlock() rs.values[key] = value @@ -62,7 +65,7 @@ func (rs *SessionStore) Set(key, value interface{}) error { } // Get value in memcache session -func (rs *SessionStore) Get(key interface{}) interface{} { +func (rs *SessionStore) Get(ctx context.Context, key interface{}) interface{} { rs.lock.RLock() defer rs.lock.RUnlock() if v, ok := rs.values[key]; ok { @@ -72,7 +75,7 @@ func (rs *SessionStore) Get(key interface{}) interface{} { } // Delete value in memcache session -func (rs *SessionStore) Delete(key interface{}) error { +func (rs *SessionStore) Delete(ctx context.Context, key interface{}) error { rs.lock.Lock() defer rs.lock.Unlock() delete(rs.values, key) @@ -80,7 +83,7 @@ func (rs *SessionStore) Delete(key interface{}) error { } // Flush clear all values in memcache session -func (rs *SessionStore) Flush() error { +func (rs *SessionStore) Flush(context.Context) error { rs.lock.Lock() defer rs.lock.Unlock() rs.values = make(map[interface{}]interface{}) @@ -88,18 +91,34 @@ func (rs *SessionStore) Flush() error { } // SessionID get memcache session id -func (rs *SessionStore) SessionID() string { +func (rs *SessionStore) SessionID(context.Context) string { return rs.sid } // SessionRelease save session values to memcache -func (rs *SessionStore) SessionRelease(w http.ResponseWriter) { - b, err := session.EncodeGob(rs.values) +func (rs *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) { + rs.releaseSession(ctx, w, false) +} + +// SessionReleaseIfPresent save session values to memcache when key is present +func (rs *SessionStore) SessionReleaseIfPresent(ctx context.Context, w http.ResponseWriter) { + rs.releaseSession(ctx, w, true) +} + +func (rs *SessionStore) releaseSession(_ context.Context, _ http.ResponseWriter, requirePresent bool) { + rs.lock.RLock() + values := rs.values + rs.lock.RUnlock() + b, err := session.EncodeGob(values) if err != nil { return } item := memcache.Item{Key: rs.sid, Value: b, Expiration: int32(rs.maxlifetime)} - client.Set(&item) + if requirePresent { + client.Replace(&item) + } else { + client.Set(&item) + } } // MemProvider memcache session provider @@ -113,7 +132,7 @@ type MemProvider struct { // SessionInit init memcache session // savepath like // e.g. 127.0.0.1:9090 -func (rp *MemProvider) SessionInit(maxlifetime int64, savePath string) error { +func (rp *MemProvider) SessionInit(ctx context.Context, maxlifetime int64, savePath string) error { rp.maxlifetime = maxlifetime rp.conninfo = strings.Split(savePath, ";") client = memcache.New(rp.conninfo...) @@ -121,7 +140,7 @@ func (rp *MemProvider) SessionInit(maxlifetime int64, savePath string) error { } // SessionRead read memcache session by sid -func (rp *MemProvider) SessionRead(sid string) (session.Store, error) { +func (rp *MemProvider) SessionRead(ctx context.Context, sid string) (session.Store, error) { if client == nil { if err := rp.connectInit(); err != nil { return nil, err @@ -149,20 +168,20 @@ func (rp *MemProvider) SessionRead(sid string) (session.Store, error) { } // SessionExist check memcache session exist by sid -func (rp *MemProvider) SessionExist(sid string) bool { +func (rp *MemProvider) SessionExist(ctx context.Context, sid string) (bool, error) { if client == nil { if err := rp.connectInit(); err != nil { - return false + return false, err } } if item, err := client.Get(sid); err != nil || len(item.Value) == 0 { - return false + return false, err } - return true + return true, nil } // SessionRegenerate generate new sid for memcache session -func (rp *MemProvider) SessionRegenerate(oldsid, sid string) (session.Store, error) { +func (rp *MemProvider) SessionRegenerate(ctx context.Context, oldsid, sid string) (session.Store, error) { if client == nil { if err := rp.connectInit(); err != nil { return nil, err @@ -170,8 +189,8 @@ func (rp *MemProvider) SessionRegenerate(oldsid, sid string) (session.Store, err } var contain []byte if item, err := client.Get(sid); err != nil || len(item.Value) == 0 { - // oldsid doesn't exists, set the new sid directly - // ignore error here, since if it return error + // oldsid doesn't exist, set the new sid directly + // ignore error here, since if it returns error // the existed value will be 0 item.Key = sid item.Value = []byte("") @@ -201,7 +220,7 @@ func (rp *MemProvider) SessionRegenerate(oldsid, sid string) (session.Store, err } // SessionDestroy delete memcache session by id -func (rp *MemProvider) SessionDestroy(sid string) error { +func (rp *MemProvider) SessionDestroy(ctx context.Context, sid string) error { if client == nil { if err := rp.connectInit(); err != nil { return err @@ -216,12 +235,12 @@ func (rp *MemProvider) connectInit() error { return nil } -// SessionGC Impelment method, no used. -func (rp *MemProvider) SessionGC() { +// SessionGC Implement method, no used. +func (rp *MemProvider) SessionGC(context.Context) { } // SessionAll return all activeSession -func (rp *MemProvider) SessionAll() int { +func (rp *MemProvider) SessionAll(context.Context) int { return 0 } diff --git a/session/mysql/sess_mysql.go b/server/web/session/mysql/sess_mysql.go similarity index 74% rename from session/mysql/sess_mysql.go rename to server/web/session/mysql/sess_mysql.go index 301353ab37..fc5e0aaef4 100644 --- a/session/mysql/sess_mysql.go +++ b/server/web/session/mysql/sess_mysql.go @@ -19,6 +19,7 @@ // go install github.com/go-sql-driver/mysql // // mysql session support need create table as sql: +// // CREATE TABLE `session` ( // `session_key` char(64) NOT NULL, // `session_data` blob, @@ -28,27 +29,28 @@ // // Usage: // import( -// _ "github.com/astaxie/beego/session/mysql" -// "github.com/astaxie/beego/session" +// +// _ "github.com/beego/beego/v2/server/web/session/mysql" +// "github.com/beego/beego/v2/server/web/session" +// // ) // // func init() { // globalSessions, _ = session.NewManager("mysql", ``{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"[username[:password]@][protocol[(address)]]/dbname[?param1=value1&...¶mN=valueN]"}``) // go globalSessions.GC() // } -// -// more docs: http://beego.me/docs/module/session.md package mysql import ( + "context" "database/sql" "net/http" "sync" "time" - "github.com/astaxie/beego/session" - // import mysql driver _ "github.com/go-sql-driver/mysql" + + "github.com/beego/beego/v2/server/web/session" ) var ( @@ -67,7 +69,7 @@ type SessionStore struct { // Set value in mysql session. // it is temp value in map. -func (st *SessionStore) Set(key, value interface{}) error { +func (st *SessionStore) Set(ctx context.Context, key, value interface{}) error { st.lock.Lock() defer st.lock.Unlock() st.values[key] = value @@ -75,7 +77,7 @@ func (st *SessionStore) Set(key, value interface{}) error { } // Get value from mysql session -func (st *SessionStore) Get(key interface{}) interface{} { +func (st *SessionStore) Get(ctx context.Context, key interface{}) interface{} { st.lock.RLock() defer st.lock.RUnlock() if v, ok := st.values[key]; ok { @@ -85,7 +87,7 @@ func (st *SessionStore) Get(key interface{}) interface{} { } // Delete value in mysql session -func (st *SessionStore) Delete(key interface{}) error { +func (st *SessionStore) Delete(ctx context.Context, key interface{}) error { st.lock.Lock() defer st.lock.Unlock() delete(st.values, key) @@ -93,7 +95,7 @@ func (st *SessionStore) Delete(key interface{}) error { } // Flush clear all values in mysql session -func (st *SessionStore) Flush() error { +func (st *SessionStore) Flush(context.Context) error { st.lock.Lock() defer st.lock.Unlock() st.values = make(map[interface{}]interface{}) @@ -101,15 +103,18 @@ func (st *SessionStore) Flush() error { } // SessionID get session id of this mysql session store -func (st *SessionStore) SessionID() string { +func (st *SessionStore) SessionID(context.Context) string { return st.sid } // SessionRelease save mysql session values to database. // must call this method to save values to database. -func (st *SessionStore) SessionRelease(w http.ResponseWriter) { +func (st *SessionStore) SessionRelease(_ context.Context, _ http.ResponseWriter) { defer st.c.Close() - b, err := session.EncodeGob(st.values) + st.lock.RLock() + values := st.values + st.lock.RUnlock() + b, err := session.EncodeGob(values) if err != nil { return } @@ -117,6 +122,11 @@ func (st *SessionStore) SessionRelease(w http.ResponseWriter) { b, time.Now().Unix(), st.sid) } +// SessionReleaseIfPresent save mysql session values to database. +func (st *SessionStore) SessionReleaseIfPresent(ctx context.Context, w http.ResponseWriter) { + st.SessionRelease(ctx, w) +} + // Provider mysql session provider type Provider struct { maxlifetime int64 @@ -134,14 +144,14 @@ func (mp *Provider) connectInit() *sql.DB { // SessionInit init mysql session. // savepath is the connection string of mysql. -func (mp *Provider) SessionInit(maxlifetime int64, savePath string) error { +func (mp *Provider) SessionInit(ctx context.Context, maxlifetime int64, savePath string) error { mp.maxlifetime = maxlifetime mp.savePath = savePath return nil } // SessionRead get mysql session by sid -func (mp *Provider) SessionRead(sid string) (session.Store, error) { +func (mp *Provider) SessionRead(ctx context.Context, sid string) (session.Store, error) { c := mp.connectInit() row := c.QueryRow("select session_data from "+TableName+" where session_key=?", sid) var sessiondata []byte @@ -149,6 +159,8 @@ func (mp *Provider) SessionRead(sid string) (session.Store, error) { if err == sql.ErrNoRows { c.Exec("insert into "+TableName+"(`session_key`,`session_data`,`session_expiry`) values(?,?,?)", sid, "", time.Now().Unix()) + } else if err != nil { + return nil, err } var kv map[interface{}]interface{} if len(sessiondata) == 0 { @@ -164,17 +176,23 @@ func (mp *Provider) SessionRead(sid string) (session.Store, error) { } // SessionExist check mysql session exist -func (mp *Provider) SessionExist(sid string) bool { +func (mp *Provider) SessionExist(ctx context.Context, sid string) (bool, error) { c := mp.connectInit() defer c.Close() row := c.QueryRow("select session_data from "+TableName+" where session_key=?", sid) var sessiondata []byte err := row.Scan(&sessiondata) - return err != sql.ErrNoRows + if err != nil { + if err == sql.ErrNoRows { + return false, nil + } + return false, err + } + return true, nil } // SessionRegenerate generate new sid for mysql session -func (mp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { +func (mp *Provider) SessionRegenerate(ctx context.Context, oldsid, sid string) (session.Store, error) { c := mp.connectInit() row := c.QueryRow("select session_data from "+TableName+" where session_key=?", oldsid) var sessiondata []byte @@ -182,7 +200,10 @@ func (mp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) if err == sql.ErrNoRows { c.Exec("insert into "+TableName+"(`session_key`,`session_data`,`session_expiry`) values(?,?,?)", oldsid, "", time.Now().Unix()) } - c.Exec("update "+TableName+" set `session_key`=? where session_key=?", sid, oldsid) + _, err = c.Exec("update "+TableName+" set `session_key`=? where session_key=?", sid, oldsid) + if err != nil { + return nil, err + } var kv map[interface{}]interface{} if len(sessiondata) == 0 { kv = make(map[interface{}]interface{}) @@ -197,7 +218,7 @@ func (mp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) } // SessionDestroy delete mysql session by sid -func (mp *Provider) SessionDestroy(sid string) error { +func (mp *Provider) SessionDestroy(ctx context.Context, sid string) error { c := mp.connectInit() c.Exec("DELETE FROM "+TableName+" where session_key=?", sid) c.Close() @@ -205,14 +226,14 @@ func (mp *Provider) SessionDestroy(sid string) error { } // SessionGC delete expired values in mysql session -func (mp *Provider) SessionGC() { +func (mp *Provider) SessionGC(context.Context) { c := mp.connectInit() c.Exec("DELETE from "+TableName+" where session_expiry < ?", time.Now().Unix()-mp.maxlifetime) c.Close() } // SessionAll count values in mysql session -func (mp *Provider) SessionAll() int { +func (mp *Provider) SessionAll(context.Context) int { c := mp.connectInit() defer c.Close() var total int diff --git a/session/postgres/sess_postgresql.go b/server/web/session/postgres/sess_postgresql.go similarity index 77% rename from session/postgres/sess_postgresql.go rename to server/web/session/postgres/sess_postgresql.go index 0b8b96457b..4e6f5e4a40 100644 --- a/session/postgres/sess_postgresql.go +++ b/server/web/session/postgres/sess_postgresql.go @@ -18,7 +18,6 @@ // // go install github.com/lib/pq // -// // needs this table in your database: // // CREATE TABLE session ( @@ -35,30 +34,30 @@ // SessionSavePath = "user=a password=b dbname=c sslmode=disable" // SessionName = session // -// // Usage: // import( -// _ "github.com/astaxie/beego/session/postgresql" -// "github.com/astaxie/beego/session" +// +// _ "github.com/beego/beego/v2/server/web/session/postgresql" +// "github.com/beego/beego/v2/server/web/session" +// // ) // // func init() { // globalSessions, _ = session.NewManager("postgresql", ``{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"user=pqgotest dbname=pqgotest sslmode=verify-full"}``) // go globalSessions.GC() // } -// -// more docs: http://beego.me/docs/module/session.md package postgres import ( + "context" "database/sql" "net/http" "sync" "time" - "github.com/astaxie/beego/session" - // import postgresql Driver _ "github.com/lib/pq" + + "github.com/beego/beego/v2/server/web/session" ) var postgresqlpder = &Provider{} @@ -73,7 +72,7 @@ type SessionStore struct { // Set value in postgresql session. // it is temp value in map. -func (st *SessionStore) Set(key, value interface{}) error { +func (st *SessionStore) Set(ctx context.Context, key, value interface{}) error { st.lock.Lock() defer st.lock.Unlock() st.values[key] = value @@ -81,7 +80,7 @@ func (st *SessionStore) Set(key, value interface{}) error { } // Get value from postgresql session -func (st *SessionStore) Get(key interface{}) interface{} { +func (st *SessionStore) Get(ctx context.Context, key interface{}) interface{} { st.lock.RLock() defer st.lock.RUnlock() if v, ok := st.values[key]; ok { @@ -91,7 +90,7 @@ func (st *SessionStore) Get(key interface{}) interface{} { } // Delete value in postgresql session -func (st *SessionStore) Delete(key interface{}) error { +func (st *SessionStore) Delete(ctx context.Context, key interface{}) error { st.lock.Lock() defer st.lock.Unlock() delete(st.values, key) @@ -99,7 +98,7 @@ func (st *SessionStore) Delete(key interface{}) error { } // Flush clear all values in postgresql session -func (st *SessionStore) Flush() error { +func (st *SessionStore) Flush(context.Context) error { st.lock.Lock() defer st.lock.Unlock() st.values = make(map[interface{}]interface{}) @@ -107,21 +106,28 @@ func (st *SessionStore) Flush() error { } // SessionID get session id of this postgresql session store -func (st *SessionStore) SessionID() string { +func (st *SessionStore) SessionID(context.Context) string { return st.sid } // SessionRelease save postgresql session values to database. // must call this method to save values to database. -func (st *SessionStore) SessionRelease(w http.ResponseWriter) { +func (st *SessionStore) SessionRelease(_ context.Context, _ http.ResponseWriter) { defer st.c.Close() - b, err := session.EncodeGob(st.values) + st.lock.RLock() + values := st.values + st.lock.RUnlock() + b, err := session.EncodeGob(values) if err != nil { return } st.c.Exec("UPDATE session set session_data=$1, session_expiry=$2 where session_key=$3", b, time.Now().Format(time.RFC3339), st.sid) +} +// SessionReleaseIfPresent save postgresql session values to database when key is present +func (st *SessionStore) SessionReleaseIfPresent(ctx context.Context, w http.ResponseWriter) { + st.SessionRelease(ctx, w) } // Provider postgresql session provider @@ -141,14 +147,14 @@ func (mp *Provider) connectInit() *sql.DB { // SessionInit init postgresql session. // savepath is the connection string of postgresql. -func (mp *Provider) SessionInit(maxlifetime int64, savePath string) error { +func (mp *Provider) SessionInit(ctx context.Context, maxlifetime int64, savePath string) error { mp.maxlifetime = maxlifetime mp.savePath = savePath return nil } // SessionRead get postgresql session by sid -func (mp *Provider) SessionRead(sid string) (session.Store, error) { +func (mp *Provider) SessionRead(ctx context.Context, sid string) (session.Store, error) { c := mp.connectInit() row := c.QueryRow("select session_data from session where session_key=$1", sid) var sessiondata []byte @@ -178,17 +184,23 @@ func (mp *Provider) SessionRead(sid string) (session.Store, error) { } // SessionExist check postgresql session exist -func (mp *Provider) SessionExist(sid string) bool { +func (mp *Provider) SessionExist(ctx context.Context, sid string) (bool, error) { c := mp.connectInit() defer c.Close() row := c.QueryRow("select session_data from session where session_key=$1", sid) var sessiondata []byte err := row.Scan(&sessiondata) - return err != sql.ErrNoRows + if err != nil { + if err == sql.ErrNoRows { + return false, nil + } + return false, err + } + return true, nil } // SessionRegenerate generate new sid for postgresql session -func (mp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { +func (mp *Provider) SessionRegenerate(ctx context.Context, oldsid, sid string) (session.Store, error) { c := mp.connectInit() row := c.QueryRow("select session_data from session where session_key=$1", oldsid) var sessiondata []byte @@ -212,7 +224,7 @@ func (mp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) } // SessionDestroy delete postgresql session by sid -func (mp *Provider) SessionDestroy(sid string) error { +func (mp *Provider) SessionDestroy(ctx context.Context, sid string) error { c := mp.connectInit() c.Exec("DELETE FROM session where session_key=$1", sid) c.Close() @@ -220,14 +232,14 @@ func (mp *Provider) SessionDestroy(sid string) error { } // SessionGC delete expired values in postgresql session -func (mp *Provider) SessionGC() { +func (mp *Provider) SessionGC(context.Context) { c := mp.connectInit() c.Exec("DELETE from session where EXTRACT(EPOCH FROM (current_timestamp - session_expiry)) > $1", mp.maxlifetime) c.Close() } // SessionAll count values in postgresql session -func (mp *Provider) SessionAll() int { +func (mp *Provider) SessionAll(context.Context) int { c := mp.connectInit() defer c.Close() var total int diff --git a/server/web/session/redis/sess_redis.go b/server/web/session/redis/sess_redis.go new file mode 100644 index 0000000000..ccb51a600a --- /dev/null +++ b/server/web/session/redis/sess_redis.go @@ -0,0 +1,297 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package redis for session provider +// +// depend on github.com/gomodule/redigo/redis +// +// go install github.com/gomodule/redigo/redis +// +// Usage: +// import( +// +// _ "github.com/beego/beego/v2/server/web/session/redis" +// "github.com/beego/beego/v2/server/web/session" +// +// ) +// +// func init() { +// globalSessions, _ = session.NewManager("redis", ``{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"127.0.0.1:7070"}``) +// go globalSessions.GC() +// } +package redis + +import ( + "context" + "encoding/json" + "net/http" + "strconv" + "strings" + "sync" + "time" + + "github.com/redis/go-redis/v9" + + "github.com/beego/beego/v2/server/web/session" +) + +var redispder = &Provider{} + +// MaxPoolSize redis max pool size +var MaxPoolSize = 100 + +// SessionStore redis session store +type SessionStore struct { + p *redis.Client + sid string + lock sync.RWMutex + values map[interface{}]interface{} + maxlifetime int64 +} + +// Set value in redis session +func (rs *SessionStore) Set(ctx context.Context, key, value interface{}) error { + rs.lock.Lock() + defer rs.lock.Unlock() + rs.values[key] = value + return nil +} + +// Get value in redis session +func (rs *SessionStore) Get(ctx context.Context, key interface{}) interface{} { + rs.lock.RLock() + defer rs.lock.RUnlock() + if v, ok := rs.values[key]; ok { + return v + } + return nil +} + +// Delete value in redis session +func (rs *SessionStore) Delete(ctx context.Context, key interface{}) error { + rs.lock.Lock() + defer rs.lock.Unlock() + delete(rs.values, key) + return nil +} + +// Flush clear all values in redis session +func (rs *SessionStore) Flush(context.Context) error { + rs.lock.Lock() + defer rs.lock.Unlock() + rs.values = make(map[interface{}]interface{}) + return nil +} + +// SessionID get redis session id +func (rs *SessionStore) SessionID(context.Context) string { + return rs.sid +} + +// SessionRelease save session values to redis +func (rs *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) { + rs.releaseSession(ctx, w, false) +} + +// SessionReleaseIfPresent save session values to redis when key is present +func (rs *SessionStore) SessionReleaseIfPresent(ctx context.Context, w http.ResponseWriter) { + rs.releaseSession(ctx, w, true) +} + +func (rs *SessionStore) releaseSession(ctx context.Context, _ http.ResponseWriter, requirePresent bool) { + rs.lock.RLock() + values := rs.values + rs.lock.RUnlock() + b, err := session.EncodeGob(values) + if err != nil { + return + } + c := rs.p + if requirePresent { + c.SetXX(ctx, rs.sid, string(b), time.Duration(rs.maxlifetime)*time.Second) + } else { + c.Set(ctx, rs.sid, string(b), time.Duration(rs.maxlifetime)*time.Second) + } +} + +// Provider redis session provider +type Provider struct { + maxlifetime int64 + SavePath string `json:"save_path"` + Poolsize int `json:"poolsize"` + Password string `json:"password"` + DbNum int `json:"db_num"` + + idleTimeout time.Duration + IdleTimeoutStr string `json:"idle_timeout"` + + idleCheckFrequency time.Duration + IdleCheckFrequencyStr string `json:"idle_check_frequency"` + MaxRetries int `json:"max_retries"` + poollist *redis.Client +} + +// SessionInit init redis session +// savepath like redis server addr,pool size,password,dbnum,IdleTimeout second +// v1.x e.g. 127.0.0.1:6379,100,astaxie,0,30 +// v2.0 you should pass json string +func (rp *Provider) SessionInit(ctx context.Context, maxlifetime int64, cfgStr string) error { + rp.maxlifetime = maxlifetime + + cfgStr = strings.TrimSpace(cfgStr) + // we think cfgStr is v2.0, using json to init the session + if strings.HasPrefix(cfgStr, "{") { + err := json.Unmarshal([]byte(cfgStr), rp) + if err != nil { + return err + } + rp.idleTimeout, err = time.ParseDuration(rp.IdleTimeoutStr) + if err != nil { + return err + } + + rp.idleCheckFrequency, err = time.ParseDuration(rp.IdleCheckFrequencyStr) + if err != nil { + return err + } + + } else { + rp.initOldStyle(cfgStr) + } + + rp.poollist = redis.NewClient(&redis.Options{ + Addr: rp.SavePath, + Password: rp.Password, + PoolSize: rp.Poolsize, + DB: rp.DbNum, + ConnMaxIdleTime: rp.idleTimeout, + MaxRetries: rp.MaxRetries, + }) + + return rp.poollist.Ping(ctx).Err() +} + +func (rp *Provider) initOldStyle(savePath string) { + configs := strings.Split(savePath, ",") + if len(configs) > 0 { + rp.SavePath = configs[0] + } + if len(configs) > 1 { + poolsize, err := strconv.Atoi(configs[1]) + if err != nil || poolsize < 0 { + rp.Poolsize = MaxPoolSize + } else { + rp.Poolsize = poolsize + } + } else { + rp.Poolsize = MaxPoolSize + } + if len(configs) > 2 { + rp.Password = configs[2] + } + if len(configs) > 3 { + dbnum, err := strconv.Atoi(configs[3]) + if err != nil || dbnum < 0 { + rp.DbNum = 0 + } else { + rp.DbNum = dbnum + } + } else { + rp.DbNum = 0 + } + if len(configs) > 4 { + timeout, err := strconv.Atoi(configs[4]) + if err == nil && timeout > 0 { + rp.idleTimeout = time.Duration(timeout) * time.Second + } + } + if len(configs) > 5 { + checkFrequency, err := strconv.Atoi(configs[5]) + if err == nil && checkFrequency > 0 { + rp.idleCheckFrequency = time.Duration(checkFrequency) * time.Second + } + } + if len(configs) > 6 { + retries, err := strconv.Atoi(configs[6]) + if err == nil && retries > 0 { + rp.MaxRetries = retries + } + } +} + +// SessionRead read redis session by sid +func (rp *Provider) SessionRead(ctx context.Context, sid string) (session.Store, error) { + var kv map[interface{}]interface{} + + kvs, err := rp.poollist.Get(ctx, sid).Result() + if err != nil && err != redis.Nil { + return nil, err + } + if len(kvs) == 0 { + kv = make(map[interface{}]interface{}) + } else { + if kv, err = session.DecodeGob([]byte(kvs)); err != nil { + return nil, err + } + } + + rs := &SessionStore{p: rp.poollist, sid: sid, values: kv, maxlifetime: rp.maxlifetime} + return rs, nil +} + +// SessionExist check redis session exist by sid +func (rp *Provider) SessionExist(ctx context.Context, sid string) (bool, error) { + c := rp.poollist + + if existed, err := c.Exists(ctx, sid).Result(); err != nil || existed == 0 { + return false, err + } + return true, nil +} + +// SessionRegenerate generate new sid for redis session +func (rp *Provider) SessionRegenerate(ctx context.Context, oldsid, sid string) (session.Store, error) { + c := rp.poollist + if existed, _ := c.Exists(ctx, oldsid).Result(); existed == 0 { + // oldsid doesn't exists, set the new sid directly + // ignore error here, since if it return error + // the existed value will be 0 + c.Do(ctx, "SET", sid, "", "EX", rp.maxlifetime) + } else { + c.Rename(ctx, oldsid, sid) + c.Expire(ctx, sid, time.Duration(rp.maxlifetime)*time.Second) + } + return rp.SessionRead(ctx, sid) +} + +// SessionDestroy delete redis session by id +func (rp *Provider) SessionDestroy(ctx context.Context, sid string) error { + c := rp.poollist + + c.Del(ctx, sid) + return nil +} + +// SessionGC Impelment method, no used. +func (rp *Provider) SessionGC(context.Context) { +} + +// SessionAll return all activeSession +func (rp *Provider) SessionAll(context.Context) int { + return 0 +} + +func init() { + session.Register("redis", redispder) +} diff --git a/server/web/session/redis/sess_redis_test.go b/server/web/session/redis/sess_redis_test.go new file mode 100644 index 0000000000..424c3aa770 --- /dev/null +++ b/server/web/session/redis/sess_redis_test.go @@ -0,0 +1,161 @@ +package redis + +import ( + "context" + "fmt" + "net/http" + "net/http/httptest" + "os" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "github.com/beego/beego/v2/server/web/session" +) + +func TestRedis(t *testing.T) { + globalSession, err := setupSessionManager(t) + if err != nil { + t.Fatal(err) + } + + go globalSession.GC() + + ctx := context.Background() + + r, _ := http.NewRequest("GET", "/", nil) + w := httptest.NewRecorder() + + sess, err := globalSession.SessionStart(w, r) + if err != nil { + t.Fatal("session start failed:", err) + } + defer sess.SessionRelease(ctx, w) + + // SET AND GET + err = sess.Set(ctx, "username", "astaxie") + if err != nil { + t.Fatal("set username failed:", err) + } + username := sess.Get(ctx, "username") + if username != "astaxie" { + t.Fatal("get username failed") + } + + // DELETE + err = sess.Delete(ctx, "username") + if err != nil { + t.Fatal("delete username failed:", err) + } + username = sess.Get(ctx, "username") + if username != nil { + t.Fatal("delete username failed") + } + + // FLUSH + err = sess.Set(ctx, "username", "astaxie") + if err != nil { + t.Fatal("set failed:", err) + } + err = sess.Set(ctx, "password", "1qaz2wsx") + if err != nil { + t.Fatal("set failed:", err) + } + username = sess.Get(ctx, "username") + if username != "astaxie" { + t.Fatal("get username failed") + } + password := sess.Get(ctx, "password") + if password != "1qaz2wsx" { + t.Fatal("get password failed") + } + err = sess.Flush(ctx) + if err != nil { + t.Fatal("flush failed:", err) + } + username = sess.Get(ctx, "username") + if username != nil { + t.Fatal("flush failed") + } + password = sess.Get(ctx, "password") + if password != nil { + t.Fatal("flush failed") + } + + sess.SessionRelease(ctx, w) +} + +func TestProvider_SessionInit(t *testing.T) { + savePath := ` +{ "save_path": "my save path", "idle_timeout": "3s"} +` + cp := &Provider{} + cp.SessionInit(context.Background(), 12, savePath) + assert.Equal(t, "my save path", cp.SavePath) + assert.Equal(t, 3*time.Second, cp.idleTimeout) + assert.Equal(t, int64(12), cp.maxlifetime) +} + +func TestStoreSessionReleaseIfPresentAndSessionDestroy(t *testing.T) { + globalSessions, err := setupSessionManager(t) + if err != nil { + t.Fatal(err) + } + // todo test if e==nil + go globalSessions.GC() + + ctx := context.Background() + + r, _ := http.NewRequest("GET", "/", nil) + w := httptest.NewRecorder() + + sess, err := globalSessions.SessionStart(w, r) + if err != nil { + t.Fatal("session start failed:", err) + } + + if err := globalSessions.GetProvider().SessionDestroy(ctx, sess.SessionID(ctx)); err != nil { + t.Error(err) + return + } + wg := sync.WaitGroup{} + wg.Add(1) + go func() { + defer wg.Done() + sess.SessionReleaseIfPresent(ctx, httptest.NewRecorder()) + }() + wg.Wait() + exist, err := globalSessions.GetProvider().SessionExist(ctx, sess.SessionID(ctx)) + if err != nil { + t.Error(err) + } + if exist { + t.Fatalf("session %s should exist", sess.SessionID(ctx)) + } +} + +func setupSessionManager(t *testing.T) (*session.Manager, error) { + redisAddr := os.Getenv("REDIS_ADDR") + if redisAddr == "" { + redisAddr = "127.0.0.1:6379" + } + redisConfig := fmt.Sprintf("%s,100,,0,30", redisAddr) + + sessionConfig := session.NewManagerConfig( + session.CfgCookieName(`gosessionid`), + session.CfgSetCookie(true), + session.CfgGcLifeTime(3600), + session.CfgMaxLifeTime(3600), + session.CfgSecure(false), + session.CfgCookieLifeTime(3600), + session.CfgProviderConfig(redisConfig), + ) + globalSessions, err := session.NewManager("redis", sessionConfig) + if err != nil { + t.Log("could not create manager: ", err) + return nil, err + } + return globalSessions, nil +} diff --git a/server/web/session/redis_cluster/redis_cluster.go b/server/web/session/redis_cluster/redis_cluster.go new file mode 100644 index 0000000000..5a6676b5e9 --- /dev/null +++ b/server/web/session/redis_cluster/redis_cluster.go @@ -0,0 +1,292 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package redis for session provider +// +// depend on github.com/redis/go-redis +// +// go install github.com/redis/go-redis +// +// Usage: +// import( +// +// _ "github.com/beego/beego/v2/server/web/session/redis_cluster" +// "github.com/beego/beego/v2/server/web/session" +// +// ) +// +// func init() { +// globalSessions, _ = session.NewManager("redis_cluster", ``{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"127.0.0.1:7070;127.0.0.1:7071"}``) +// go globalSessions.GC() +// } +package redis_cluster + +import ( + "context" + "encoding/json" + "net/http" + "strconv" + "strings" + "sync" + "time" + + rediss "github.com/redis/go-redis/v9" + + "github.com/beego/beego/v2/server/web/session" +) + +var redispder = &Provider{} + +// MaxPoolSize redis_cluster max pool size +var MaxPoolSize = 1000 + +// SessionStore redis_cluster session store +type SessionStore struct { + p *rediss.ClusterClient + sid string + lock sync.RWMutex + values map[interface{}]interface{} + maxlifetime int64 +} + +// Set value in redis_cluster session +func (rs *SessionStore) Set(ctx context.Context, key, value interface{}) error { + rs.lock.Lock() + defer rs.lock.Unlock() + rs.values[key] = value + return nil +} + +// Get value in redis_cluster session +func (rs *SessionStore) Get(ctx context.Context, key interface{}) interface{} { + rs.lock.RLock() + defer rs.lock.RUnlock() + if v, ok := rs.values[key]; ok { + return v + } + return nil +} + +// Delete value in redis_cluster session +func (rs *SessionStore) Delete(ctx context.Context, key interface{}) error { + rs.lock.Lock() + defer rs.lock.Unlock() + delete(rs.values, key) + return nil +} + +// Flush clear all values in redis_cluster session +func (rs *SessionStore) Flush(context.Context) error { + rs.lock.Lock() + defer rs.lock.Unlock() + rs.values = make(map[interface{}]interface{}) + return nil +} + +// SessionID get redis_cluster session id +func (rs *SessionStore) SessionID(context.Context) string { + return rs.sid +} + +// SessionRelease save session values to redis_cluster +func (rs *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) { + rs.releaseSession(ctx, w, false) +} + +// SessionReleaseIfPresent save session values to redis_cluster when key is present +func (rs *SessionStore) SessionReleaseIfPresent(ctx context.Context, w http.ResponseWriter) { + rs.releaseSession(ctx, w, true) +} + +func (rs *SessionStore) releaseSession(ctx context.Context, _ http.ResponseWriter, requirePresent bool) { + rs.lock.RLock() + values := rs.values + rs.lock.RUnlock() + b, err := session.EncodeGob(values) + if err != nil { + return + } + c := rs.p + if requirePresent { + c.SetXX(ctx, rs.sid, string(b), time.Duration(rs.maxlifetime)*time.Second) + } else { + c.Set(ctx, rs.sid, string(b), time.Duration(rs.maxlifetime)*time.Second) + } +} + +// Provider redis_cluster session provider +type Provider struct { + maxlifetime int64 + SavePath string `json:"save_path"` + Poolsize int `json:"poolsize"` + Password string `json:"password"` + DbNum int `json:"db_num"` + + idleTimeout time.Duration + IdleTimeoutStr string `json:"idle_timeout"` + + idleCheckFrequency time.Duration + IdleCheckFrequencyStr string `json:"idle_check_frequency"` + MaxRetries int `json:"max_retries"` + poollist *rediss.ClusterClient +} + +// SessionInit init redis_cluster session +// cfgStr like redis server addr,pool size,password,dbnum +// e.g. 127.0.0.1:6379;127.0.0.1:6380,100,test,0 +func (rp *Provider) SessionInit(ctx context.Context, maxlifetime int64, cfgStr string) error { + rp.maxlifetime = maxlifetime + cfgStr = strings.TrimSpace(cfgStr) + // we think cfgStr is v2.0, using json to init the session + if strings.HasPrefix(cfgStr, "{") { + err := json.Unmarshal([]byte(cfgStr), rp) + if err != nil { + return err + } + rp.idleTimeout, err = time.ParseDuration(rp.IdleTimeoutStr) + if err != nil { + return err + } + + rp.idleCheckFrequency, err = time.ParseDuration(rp.IdleCheckFrequencyStr) + if err != nil { + return err + } + + } else { + rp.initOldStyle(cfgStr) + } + + rp.poollist = rediss.NewClusterClient(&rediss.ClusterOptions{ + Addrs: strings.Split(rp.SavePath, ";"), + Password: rp.Password, + PoolSize: rp.Poolsize, + ConnMaxIdleTime: rp.idleTimeout, + MaxRetries: rp.MaxRetries, + }) + return rp.poollist.Ping(ctx).Err() +} + +// for v1.x +func (rp *Provider) initOldStyle(savePath string) { + configs := strings.Split(savePath, ",") + if len(configs) > 0 { + rp.SavePath = configs[0] + } + if len(configs) > 1 { + poolsize, err := strconv.Atoi(configs[1]) + if err != nil || poolsize < 0 { + rp.Poolsize = MaxPoolSize + } else { + rp.Poolsize = poolsize + } + } else { + rp.Poolsize = MaxPoolSize + } + if len(configs) > 2 { + rp.Password = configs[2] + } + if len(configs) > 3 { + dbnum, err := strconv.Atoi(configs[3]) + if err != nil || dbnum < 0 { + rp.DbNum = 0 + } else { + rp.DbNum = dbnum + } + } else { + rp.DbNum = 0 + } + if len(configs) > 4 { + timeout, err := strconv.Atoi(configs[4]) + if err == nil && timeout > 0 { + rp.idleTimeout = time.Duration(timeout) * time.Second + } + } + if len(configs) > 5 { + checkFrequency, err := strconv.Atoi(configs[5]) + if err == nil && checkFrequency > 0 { + rp.idleCheckFrequency = time.Duration(checkFrequency) * time.Second + } + } + if len(configs) > 6 { + retries, err := strconv.Atoi(configs[6]) + if err == nil && retries > 0 { + rp.MaxRetries = retries + } + } +} + +// SessionRead read redis_cluster session by sid +func (rp *Provider) SessionRead(ctx context.Context, sid string) (session.Store, error) { + var kv map[interface{}]interface{} + kvs, err := rp.poollist.Get(ctx, sid).Result() + if err != nil && err != rediss.Nil { + return nil, err + } + if len(kvs) == 0 { + kv = make(map[interface{}]interface{}) + } else { + if kv, err = session.DecodeGob([]byte(kvs)); err != nil { + return nil, err + } + } + + rs := &SessionStore{p: rp.poollist, sid: sid, values: kv, maxlifetime: rp.maxlifetime} + return rs, nil +} + +// SessionExist check redis_cluster session exist by sid +func (rp *Provider) SessionExist(ctx context.Context, sid string) (bool, error) { + c := rp.poollist + if existed, err := c.Exists(ctx, sid).Result(); err != nil || existed == 0 { + return false, err + } + return true, nil +} + +// SessionRegenerate generate new sid for redis_cluster session +func (rp *Provider) SessionRegenerate(ctx context.Context, oldsid, sid string) (session.Store, error) { + c := rp.poollist + + if existed, err := c.Exists(ctx, oldsid).Result(); err != nil || existed == 0 { + // oldsid doesn't exists, set the new sid directly + // ignore error here, since if it return error + // the existed value will be 0 + c.Set(ctx, sid, "", time.Duration(rp.maxlifetime)*time.Second) + } else { + c.Rename(ctx, oldsid, sid) + c.Expire(ctx, sid, time.Duration(rp.maxlifetime)*time.Second) + } + return rp.SessionRead(ctx, sid) +} + +// SessionDestroy delete redis session by id +func (rp *Provider) SessionDestroy(ctx context.Context, sid string) error { + c := rp.poollist + c.Del(ctx, sid) + return nil +} + +// SessionGC Impelment method, no used. +func (rp *Provider) SessionGC(context.Context) { +} + +// SessionAll return all activeSession +func (rp *Provider) SessionAll(context.Context) int { + return 0 +} + +func init() { + session.Register("redis_cluster", redispder) +} diff --git a/server/web/session/redis_cluster/redis_cluster_test.go b/server/web/session/redis_cluster/redis_cluster_test.go new file mode 100644 index 0000000000..e518af54f0 --- /dev/null +++ b/server/web/session/redis_cluster/redis_cluster_test.go @@ -0,0 +1,34 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package redis_cluster + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestProvider_SessionInit(t *testing.T) { + savePath := ` +{ "save_path": "my save path", "idle_timeout": "3s"} +` + cp := &Provider{} + cp.SessionInit(context.Background(), 12, savePath) + assert.Equal(t, "my save path", cp.SavePath) + assert.Equal(t, 3*time.Second, cp.idleTimeout) + assert.Equal(t, int64(12), cp.maxlifetime) +} diff --git a/server/web/session/redis_sentinel/sess_redis_sentinel.go b/server/web/session/redis_sentinel/sess_redis_sentinel.go new file mode 100644 index 0000000000..c0dc75101c --- /dev/null +++ b/server/web/session/redis_sentinel/sess_redis_sentinel.go @@ -0,0 +1,307 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package redis for session provider +// +// depend on github.com/redis/go-redis +// +// go install github.com/redis/go-redis +// +// Usage: +// import( +// +// _ "github.com/beego/beego/v2/server/web/session/redis_sentinel" +// "github.com/beego/beego/v2/server/web/session" +// +// ) +// +// func init() { +// globalSessions, _ = session.NewManager("redis_sentinel", ``{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"127.0.0.1:26379;127.0.0.2:26379"}``) +// go globalSessions.GC() +// } +// +// more detail about params: please check the notes on the function SessionInit in this package +package redis_sentinel + +import ( + "context" + "encoding/json" + "net/http" + "strconv" + "strings" + "sync" + "time" + + "github.com/redis/go-redis/v9" + + "github.com/beego/beego/v2/server/web/session" +) + +var redispder = &Provider{} + +// DefaultPoolSize redis_sentinel default pool size +var DefaultPoolSize = 100 + +// SessionStore redis_sentinel session store +type SessionStore struct { + p *redis.Client + sid string + lock sync.RWMutex + values map[interface{}]interface{} + maxlifetime int64 +} + +// Set value in redis_sentinel session +func (rs *SessionStore) Set(ctx context.Context, key, value interface{}) error { + rs.lock.Lock() + defer rs.lock.Unlock() + rs.values[key] = value + return nil +} + +// Get value in redis_sentinel session +func (rs *SessionStore) Get(ctx context.Context, key interface{}) interface{} { + rs.lock.RLock() + defer rs.lock.RUnlock() + if v, ok := rs.values[key]; ok { + return v + } + return nil +} + +// Delete value in redis_sentinel session +func (rs *SessionStore) Delete(ctx context.Context, key interface{}) error { + rs.lock.Lock() + defer rs.lock.Unlock() + delete(rs.values, key) + return nil +} + +// Flush clear all values in redis_sentinel session +func (rs *SessionStore) Flush(context.Context) error { + rs.lock.Lock() + defer rs.lock.Unlock() + rs.values = make(map[interface{}]interface{}) + return nil +} + +// SessionID get redis_sentinel session id +func (rs *SessionStore) SessionID(context.Context) string { + return rs.sid +} + +// SessionRelease save session values to redis_sentinel +func (rs *SessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) { + rs.releaseSession(ctx, w, false) +} + +// SessionReleaseIfPresent save session values to redis_sentinel when key is present +func (rs *SessionStore) SessionReleaseIfPresent(ctx context.Context, w http.ResponseWriter) { + rs.releaseSession(ctx, w, true) +} + +func (rs *SessionStore) releaseSession(ctx context.Context, _ http.ResponseWriter, requirePresent bool) { + rs.lock.RLock() + values := rs.values + rs.lock.RUnlock() + b, err := session.EncodeGob(values) + if err != nil { + return + } + c := rs.p + if requirePresent { + c.SetXX(ctx, rs.sid, string(b), time.Duration(rs.maxlifetime)*time.Second) + } else { + c.Set(ctx, rs.sid, string(b), time.Duration(rs.maxlifetime)*time.Second) + } +} + +// Provider redis_sentinel session provider +type Provider struct { + maxlifetime int64 + SavePath string `json:"save_path"` + Poolsize int `json:"poolsize"` + Password string `json:"password"` + DbNum int `json:"db_num"` + + idleTimeout time.Duration + IdleTimeoutStr string `json:"idle_timeout"` + + idleCheckFrequency time.Duration + IdleCheckFrequencyStr string `json:"idle_check_frequency"` + MaxRetries int `json:"max_retries"` + poollist *redis.Client + MasterName string `json:"master_name"` +} + +// SessionInit init redis_sentinel session +// cfgStr like redis sentinel addr,pool size,password,dbnum,masterName +// e.g. 127.0.0.1:26379;127.0.0.2:26379,100,1qaz2wsx,0,mymaster +func (rp *Provider) SessionInit(ctx context.Context, maxlifetime int64, cfgStr string) error { + rp.maxlifetime = maxlifetime + cfgStr = strings.TrimSpace(cfgStr) + // we think cfgStr is v2.0, using json to init the session + if strings.HasPrefix(cfgStr, "{") { + err := json.Unmarshal([]byte(cfgStr), rp) + if err != nil { + return err + } + rp.idleTimeout, err = time.ParseDuration(rp.IdleTimeoutStr) + if err != nil { + return err + } + + rp.idleCheckFrequency, err = time.ParseDuration(rp.IdleCheckFrequencyStr) + if err != nil { + return err + } + + } else { + rp.initOldStyle(cfgStr) + } + + rp.poollist = redis.NewFailoverClient(&redis.FailoverOptions{ + SentinelAddrs: strings.Split(rp.SavePath, ";"), + Password: rp.Password, + PoolSize: rp.Poolsize, + DB: rp.DbNum, + MasterName: rp.MasterName, + ConnMaxIdleTime: rp.idleTimeout, + MaxRetries: rp.MaxRetries, + }) + + return rp.poollist.Ping(ctx).Err() +} + +// for v1.x +func (rp *Provider) initOldStyle(savePath string) { + configs := strings.Split(savePath, ",") + if len(configs) > 0 { + rp.SavePath = configs[0] + } + if len(configs) > 1 { + poolsize, err := strconv.Atoi(configs[1]) + if err != nil || poolsize < 0 { + rp.Poolsize = DefaultPoolSize + } else { + rp.Poolsize = poolsize + } + } else { + rp.Poolsize = DefaultPoolSize + } + if len(configs) > 2 { + rp.Password = configs[2] + } + if len(configs) > 3 { + dbnum, err := strconv.Atoi(configs[3]) + if err != nil || dbnum < 0 { + rp.DbNum = 0 + } else { + rp.DbNum = dbnum + } + } else { + rp.DbNum = 0 + } + if len(configs) > 4 { + if configs[4] != "" { + rp.MasterName = configs[4] + } else { + rp.MasterName = "mymaster" + } + } else { + rp.MasterName = "mymaster" + } + if len(configs) > 5 { + timeout, err := strconv.Atoi(configs[4]) + if err == nil && timeout > 0 { + rp.idleTimeout = time.Duration(timeout) * time.Second + } + } + if len(configs) > 6 { + checkFrequency, err := strconv.Atoi(configs[5]) + if err == nil && checkFrequency > 0 { + rp.idleCheckFrequency = time.Duration(checkFrequency) * time.Second + } + } + if len(configs) > 7 { + retries, err := strconv.Atoi(configs[6]) + if err == nil && retries > 0 { + rp.MaxRetries = retries + } + } +} + +// SessionRead read redis_sentinel session by sid +func (rp *Provider) SessionRead(ctx context.Context, sid string) (session.Store, error) { + var kv map[interface{}]interface{} + kvs, err := rp.poollist.Get(ctx, sid).Result() + if err != nil && err != redis.Nil { + return nil, err + } + if len(kvs) == 0 { + kv = make(map[interface{}]interface{}) + } else { + if kv, err = session.DecodeGob([]byte(kvs)); err != nil { + return nil, err + } + } + + rs := &SessionStore{p: rp.poollist, sid: sid, values: kv, maxlifetime: rp.maxlifetime} + return rs, nil +} + +// SessionExist check redis_sentinel session exist by sid +func (rp *Provider) SessionExist(ctx context.Context, sid string) (bool, error) { + c := rp.poollist + if existed, err := c.Exists(ctx, sid).Result(); err != nil || existed == 0 { + return false, err + } + return true, nil +} + +// SessionRegenerate generate new sid for redis_sentinel session +func (rp *Provider) SessionRegenerate(ctx context.Context, oldsid, sid string) (session.Store, error) { + c := rp.poollist + + if existed, err := c.Exists(ctx, oldsid).Result(); err != nil || existed == 0 { + // oldsid doesn't exists, set the new sid directly + // ignore error here, since if it return error + // the existed value will be 0 + c.Set(ctx, sid, "", time.Duration(rp.maxlifetime)*time.Second) + } else { + c.Rename(ctx, oldsid, sid) + c.Expire(ctx, sid, time.Duration(rp.maxlifetime)*time.Second) + } + return rp.SessionRead(ctx, sid) +} + +// SessionDestroy delete redis session by id +func (rp *Provider) SessionDestroy(ctx context.Context, sid string) error { + c := rp.poollist + c.Del(ctx, sid) + return nil +} + +// SessionGC Impelment method, no used. +func (rp *Provider) SessionGC(context.Context) { +} + +// SessionAll return all activeSession +func (rp *Provider) SessionAll(context.Context) int { + return 0 +} + +func init() { + session.Register("redis_sentinel", redispder) +} diff --git a/server/web/session/redis_sentinel/sess_redis_sentinel_test.go b/server/web/session/redis_sentinel/sess_redis_sentinel_test.go new file mode 100644 index 0000000000..3e47230a91 --- /dev/null +++ b/server/web/session/redis_sentinel/sess_redis_sentinel_test.go @@ -0,0 +1,155 @@ +package redis_sentinel + +import ( + "context" + "net/http" + "net/http/httptest" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "github.com/beego/beego/v2/server/web/session" +) + +func TestRedisSentinel(t *testing.T) { + globalSessions, err := setupSessionManager(t) + if err != nil { + t.Log(err) + return + } + // todo test if e==nil + go globalSessions.GC() + + ctx := context.Background() + + r, _ := http.NewRequest("GET", "/", nil) + w := httptest.NewRecorder() + + sess, err := globalSessions.SessionStart(w, r) + if err != nil { + t.Fatal("session start failed:", err) + } + defer sess.SessionRelease(ctx, w) + + // SET AND GET + err = sess.Set(ctx, "username", "astaxie") + if err != nil { + t.Fatal("set username failed:", err) + } + username := sess.Get(ctx, "username") + if username != "astaxie" { + t.Fatal("get username failed") + } + + // DELETE + err = sess.Delete(ctx, "username") + if err != nil { + t.Fatal("delete username failed:", err) + } + username = sess.Get(ctx, "username") + if username != nil { + t.Fatal("delete username failed") + } + + // FLUSH + err = sess.Set(ctx, "username", "astaxie") + if err != nil { + t.Fatal("set failed:", err) + } + err = sess.Set(ctx, "password", "1qaz2wsx") + if err != nil { + t.Fatal("set failed:", err) + } + username = sess.Get(ctx, "username") + if username != "astaxie" { + t.Fatal("get username failed") + } + password := sess.Get(ctx, "password") + if password != "1qaz2wsx" { + t.Fatal("get password failed") + } + err = sess.Flush(ctx) + if err != nil { + t.Fatal("flush failed:", err) + } + username = sess.Get(ctx, "username") + if username != nil { + t.Fatal("flush failed") + } + password = sess.Get(ctx, "password") + if password != nil { + t.Fatal("flush failed") + } + + sess.SessionRelease(ctx, w) +} + +func TestProvider_SessionInit(t *testing.T) { + savePath := ` +{ "save_path": "my save path", "idle_timeout": "3s"} +` + cp := &Provider{} + cp.SessionInit(context.Background(), 12, savePath) + assert.Equal(t, "my save path", cp.SavePath) + assert.Equal(t, 3*time.Second, cp.idleTimeout) + assert.Equal(t, int64(12), cp.maxlifetime) +} + +func TestStoreSessionReleaseIfPresentAndSessionDestroy(t *testing.T) { + globalSessions, e := setupSessionManager(t) + if e != nil { + t.Log(e) + return + } + // todo test if e==nil + go globalSessions.GC() + + ctx := context.Background() + + r, _ := http.NewRequest("GET", "/", nil) + w := httptest.NewRecorder() + + sess, err := globalSessions.SessionStart(w, r) + if err != nil { + t.Fatal("session start failed:", err) + } + + if err := globalSessions.GetProvider().SessionDestroy(ctx, sess.SessionID(ctx)); err != nil { + t.Error(err) + return + } + wg := sync.WaitGroup{} + wg.Add(1) + go func() { + defer wg.Done() + sess.SessionReleaseIfPresent(context.Background(), httptest.NewRecorder()) + }() + wg.Wait() + exist, err := globalSessions.GetProvider().SessionExist(ctx, sess.SessionID(ctx)) + if err != nil { + t.Error(err) + } + if exist { + t.Fatalf("session %s should exist", sess.SessionID(ctx)) + } +} + +func setupSessionManager(t *testing.T) (*session.Manager, error) { + sessionConfig := session.NewManagerConfig( + session.CfgCookieName(`gosessionid`), + session.CfgSetCookie(true), + session.CfgGcLifeTime(3600), + session.CfgMaxLifeTime(3600), + session.CfgSecure(false), + session.CfgCookieLifeTime(3600), + session.CfgProviderConfig("127.0.0.1:6379,100,,0,master"), + ) + globalSessions, err := session.NewManager("redis_sentinel", sessionConfig) + if err != nil { + t.Log(err) + return nil, err + } + return globalSessions, nil +} diff --git a/session/sess_cookie.go b/server/web/session/sess_cookie.go similarity index 64% rename from session/sess_cookie.go rename to server/web/session/sess_cookie.go index 145e53c9bc..822615b112 100644 --- a/session/sess_cookie.go +++ b/server/web/session/sess_cookie.go @@ -15,6 +15,7 @@ package session import ( + "context" "crypto/aes" "crypto/cipher" "encoding/json" @@ -34,7 +35,7 @@ type CookieSessionStore struct { // Set value to cookie session. // the value are encoded as gob with hash block string. -func (st *CookieSessionStore) Set(key, value interface{}) error { +func (st *CookieSessionStore) Set(ctx context.Context, key, value interface{}) error { st.lock.Lock() defer st.lock.Unlock() st.values[key] = value @@ -42,7 +43,7 @@ func (st *CookieSessionStore) Set(key, value interface{}) error { } // Get value from cookie session -func (st *CookieSessionStore) Get(key interface{}) interface{} { +func (st *CookieSessionStore) Get(ctx context.Context, key interface{}) interface{} { st.lock.RLock() defer st.lock.RUnlock() if v, ok := st.values[key]; ok { @@ -52,7 +53,7 @@ func (st *CookieSessionStore) Get(key interface{}) interface{} { } // Delete value in cookie session -func (st *CookieSessionStore) Delete(key interface{}) error { +func (st *CookieSessionStore) Delete(ctx context.Context, key interface{}) error { st.lock.Lock() defer st.lock.Unlock() delete(st.values, key) @@ -60,7 +61,7 @@ func (st *CookieSessionStore) Delete(key interface{}) error { } // Flush Clean all values in cookie session -func (st *CookieSessionStore) Flush() error { +func (st *CookieSessionStore) Flush(context.Context) error { st.lock.Lock() defer st.lock.Unlock() st.values = make(map[interface{}]interface{}) @@ -68,24 +69,36 @@ func (st *CookieSessionStore) Flush() error { } // SessionID Return id of this cookie session -func (st *CookieSessionStore) SessionID() string { +func (st *CookieSessionStore) SessionID(context.Context) string { return st.sid } // SessionRelease Write cookie session to http response cookie -func (st *CookieSessionStore) SessionRelease(w http.ResponseWriter) { - encodedCookie, err := encodeCookie(cookiepder.block, cookiepder.config.SecurityKey, cookiepder.config.SecurityName, st.values) +func (st *CookieSessionStore) SessionRelease(_ context.Context, w http.ResponseWriter) { + st.lock.RLock() + values := st.values + st.lock.RUnlock() + encodedCookie, err := encodeCookie( + cookiepder.block, cookiepder.config.SecurityKey, cookiepder.config.SecurityName, values) if err == nil { - cookie := &http.Cookie{Name: cookiepder.config.CookieName, + cookie := &http.Cookie{ + Name: cookiepder.config.CookieName, Value: url.QueryEscape(encodedCookie), Path: "/", HttpOnly: true, Secure: cookiepder.config.Secure, - MaxAge: cookiepder.config.Maxage} + MaxAge: cookiepder.config.Maxage, + } http.SetCookie(w, cookie) } } +// SessionReleaseIfPresent Write cookie session to http response cookie when it is present +// This is a no-op for cookie sessions, because they are always present. +func (st *CookieSessionStore) SessionReleaseIfPresent(ctx context.Context, w http.ResponseWriter) { + st.SessionRelease(ctx, w) +} + type cookieConfig struct { SecurityKey string `json:"securityKey"` BlockKey string `json:"blockKey"` @@ -105,12 +118,13 @@ type CookieProvider struct { // SessionInit Init cookie session provider with max lifetime and config json. // maxlifetime is ignored. // json config: -// securityKey - hash string -// blockKey - gob encode hash string. it's saved as aes crypto. -// securityName - recognized name in encoded cookie string -// cookieName - cookie name -// maxage - cookie max life time. -func (pder *CookieProvider) SessionInit(maxlifetime int64, config string) error { +// +// securityKey - hash string +// blockKey - gob encode hash string. it's saved as aes crypto. +// securityName - recognized name in encoded cookie string +// cookieName - cookie name +// maxage - cookie max life time. +func (pder *CookieProvider) SessionInit(ctx context.Context, maxlifetime int64, config string) error { pder.config = &cookieConfig{} err := json.Unmarshal([]byte(config), pder.config) if err != nil { @@ -132,7 +146,7 @@ func (pder *CookieProvider) SessionInit(maxlifetime int64, config string) error // SessionRead Get SessionStore in cooke. // decode cooke string to map and put into SessionStore with sid. -func (pder *CookieProvider) SessionRead(sid string) (Store, error) { +func (pder *CookieProvider) SessionRead(ctx context.Context, sid string) (Store, error) { maps, _ := decodeCookie(pder.block, pder.config.SecurityKey, pder.config.SecurityName, @@ -145,31 +159,31 @@ func (pder *CookieProvider) SessionRead(sid string) (Store, error) { } // SessionExist Cookie session is always existed -func (pder *CookieProvider) SessionExist(sid string) bool { - return true +func (pder *CookieProvider) SessionExist(ctx context.Context, sid string) (bool, error) { + return true, nil } // SessionRegenerate Implement method, no used. -func (pder *CookieProvider) SessionRegenerate(oldsid, sid string) (Store, error) { +func (pder *CookieProvider) SessionRegenerate(ctx context.Context, oldsid, sid string) (Store, error) { return nil, nil } // SessionDestroy Implement method, no used. -func (pder *CookieProvider) SessionDestroy(sid string) error { +func (pder *CookieProvider) SessionDestroy(ctx context.Context, sid string) error { return nil } // SessionGC Implement method, no used. -func (pder *CookieProvider) SessionGC() { +func (pder *CookieProvider) SessionGC(context.Context) { } // SessionAll Implement method, return 0. -func (pder *CookieProvider) SessionAll() int { +func (pder *CookieProvider) SessionAll(context.Context) int { return 0 } // SessionUpdate Implement method, no used. -func (pder *CookieProvider) SessionUpdate(sid string) error { +func (pder *CookieProvider) SessionUpdate(ctx context.Context, sid string) error { return nil } diff --git a/session/sess_cookie_test.go b/server/web/session/sess_cookie_test.go similarity index 91% rename from session/sess_cookie_test.go rename to server/web/session/sess_cookie_test.go index b6726005f8..5c1a42568c 100644 --- a/session/sess_cookie_test.go +++ b/server/web/session/sess_cookie_test.go @@ -38,14 +38,14 @@ func TestCookie(t *testing.T) { if err != nil { t.Fatal("set error,", err) } - err = sess.Set("username", "astaxie") + err = sess.Set(nil, "username", "astaxie") if err != nil { t.Fatal("set error,", err) } - if username := sess.Get("username"); username != "astaxie" { + if username := sess.Get(nil, "username"); username != "astaxie" { t.Fatal("get username error") } - sess.SessionRelease(w) + sess.SessionRelease(nil, w) if cookiestr := w.Header().Get("Set-Cookie"); cookiestr == "" { t.Fatal("setcookie error") } else { @@ -59,7 +59,7 @@ func TestCookie(t *testing.T) { } } -func TestDestorySessionCookie(t *testing.T) { +func TestDestroySessionCookie(t *testing.T) { config := `{"cookieName":"gosessionid","enableSetCookie":true,"gclifetime":3600,"ProviderConfig":"{\"cookieName\":\"gosessionid\",\"securityKey\":\"beegocookiehashkey\"}"}` conf := new(ManagerConfig) if err := json.Unmarshal([]byte(config), conf); err != nil { @@ -85,7 +85,7 @@ func TestDestorySessionCookie(t *testing.T) { if err != nil { t.Fatal("session start err,", err) } - if newSession.SessionID() != session.SessionID() { + if newSession.SessionID(nil) != session.SessionID(nil) { t.Fatal("get cookie session id is not the same again.") } @@ -99,7 +99,7 @@ func TestDestorySessionCookie(t *testing.T) { if err != nil { t.Fatal("session start error") } - if newSession.SessionID() == session.SessionID() { + if newSession.SessionID(nil) == session.SessionID(nil) { t.Fatal("after destroy session and reqeust again ,get cookie session id is same.") } } diff --git a/session/sess_file.go b/server/web/session/sess_file.go similarity index 61% rename from session/sess_file.go rename to server/web/session/sess_file.go index db14352230..45be923e6e 100644 --- a/session/sess_file.go +++ b/server/web/session/sess_file.go @@ -15,12 +15,12 @@ package session import ( + "context" + "errors" "fmt" - "io/ioutil" + "io" "net/http" "os" - "errors" - "path" "path/filepath" "strings" "sync" @@ -40,7 +40,7 @@ type FileSessionStore struct { } // Set value to file session -func (fs *FileSessionStore) Set(key, value interface{}) error { +func (fs *FileSessionStore) Set(ctx context.Context, key, value interface{}) error { fs.lock.Lock() defer fs.lock.Unlock() fs.values[key] = value @@ -48,7 +48,7 @@ func (fs *FileSessionStore) Set(key, value interface{}) error { } // Get value from file session -func (fs *FileSessionStore) Get(key interface{}) interface{} { +func (fs *FileSessionStore) Get(ctx context.Context, key interface{}) interface{} { fs.lock.RLock() defer fs.lock.RUnlock() if v, ok := fs.values[key]; ok { @@ -58,7 +58,7 @@ func (fs *FileSessionStore) Get(key interface{}) interface{} { } // Delete value in file session by given key -func (fs *FileSessionStore) Delete(key interface{}) error { +func (fs *FileSessionStore) Delete(ctx context.Context, key interface{}) error { fs.lock.Lock() defer fs.lock.Unlock() delete(fs.values, key) @@ -66,7 +66,7 @@ func (fs *FileSessionStore) Delete(key interface{}) error { } // Flush Clean all values in file session -func (fs *FileSessionStore) Flush() error { +func (fs *FileSessionStore) Flush(context.Context) error { fs.lock.Lock() defer fs.lock.Unlock() fs.values = make(map[interface{}]interface{}) @@ -74,12 +74,21 @@ func (fs *FileSessionStore) Flush() error { } // SessionID Get file session store id -func (fs *FileSessionStore) SessionID() string { +func (fs *FileSessionStore) SessionID(context.Context) string { return fs.sid } // SessionRelease Write file session to local file with Gob string -func (fs *FileSessionStore) SessionRelease(w http.ResponseWriter) { +func (fs *FileSessionStore) SessionRelease(ctx context.Context, w http.ResponseWriter) { + fs.releaseSession(ctx, w, true) +} + +// SessionReleaseIfPresent Write file session to local file with Gob string when session exists +func (fs *FileSessionStore) SessionReleaseIfPresent(ctx context.Context, w http.ResponseWriter) { + fs.releaseSession(ctx, w, false) +} + +func (fs *FileSessionStore) releaseSession(_ context.Context, _ http.ResponseWriter, createIfNotExist bool) { filepder.lock.Lock() defer filepder.lock.Unlock() b, err := EncodeGob(fs.values) @@ -87,16 +96,16 @@ func (fs *FileSessionStore) SessionRelease(w http.ResponseWriter) { SLogger.Println(err) return } - _, err = os.Stat(path.Join(filepder.savePath, string(fs.sid[0]), string(fs.sid[1]), fs.sid)) + _, err = os.Stat(filepath.Join(filepder.savePath, string(fs.sid[0]), string(fs.sid[1]), fs.sid)) var f *os.File if err == nil { - f, err = os.OpenFile(path.Join(filepder.savePath, string(fs.sid[0]), string(fs.sid[1]), fs.sid), os.O_RDWR, 0777) + f, err = os.OpenFile(filepath.Join(filepder.savePath, string(fs.sid[0]), string(fs.sid[1]), fs.sid), os.O_RDWR, 0o777) if err != nil { SLogger.Println(err) return } - } else if os.IsNotExist(err) { - f, err = os.Create(path.Join(filepder.savePath, string(fs.sid[0]), string(fs.sid[1]), fs.sid)) + } else if os.IsNotExist(err) && createIfNotExist { + f, err = os.Create(filepath.Join(filepder.savePath, string(fs.sid[0]), string(fs.sid[1]), fs.sid)) if err != nil { SLogger.Println(err) return @@ -119,7 +128,7 @@ type FileProvider struct { // SessionInit Init file session provider. // savePath sets the session files path. -func (fp *FileProvider) SessionInit(maxlifetime int64, savePath string) error { +func (fp *FileProvider) SessionInit(ctx context.Context, maxlifetime int64, savePath string) error { fp.maxlifetime = maxlifetime fp.savePath = savePath return nil @@ -128,9 +137,10 @@ func (fp *FileProvider) SessionInit(maxlifetime int64, savePath string) error { // SessionRead Read file session by sid. // if file is not exist, create it. // the file path is generated from sid string. -func (fp *FileProvider) SessionRead(sid string) (Store, error) { - if strings.ContainsAny(sid, "./") { - return nil, nil +func (fp *FileProvider) SessionRead(ctx context.Context, sid string) (Store, error) { + invalidChars := "./" + if strings.ContainsAny(sid, invalidChars) { + return nil, errors.New("the sid shouldn't have following characters: " + invalidChars) } if len(sid) < 2 { return nil, errors.New("length of the sid is less than 2") @@ -138,25 +148,34 @@ func (fp *FileProvider) SessionRead(sid string) (Store, error) { filepder.lock.Lock() defer filepder.lock.Unlock() - err := os.MkdirAll(path.Join(fp.savePath, string(sid[0]), string(sid[1])), 0777) + sessionPath := filepath.Join(fp.savePath, string(sid[0]), string(sid[1])) + sidPath := filepath.Join(sessionPath, sid) + err := os.MkdirAll(sessionPath, 0o755) if err != nil { SLogger.Println(err.Error()) } - _, err = os.Stat(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid)) var f *os.File - if err == nil { - f, err = os.OpenFile(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid), os.O_RDWR, 0777) - } else if os.IsNotExist(err) { - f, err = os.Create(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid)) - } else { + _, err = os.Stat(sidPath) + switch { + case err == nil: + f, err = os.OpenFile(sidPath, os.O_RDWR, 0o777) + if err != nil { + return nil, err + } + case os.IsNotExist(err): + f, err = os.Create(sidPath) + if err != nil { + return nil, err + } + default: return nil, err } defer f.Close() - os.Chtimes(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid), time.Now(), time.Now()) + os.Chtimes(sidPath, time.Now(), time.Now()) var kv map[interface{}]interface{} - b, err := ioutil.ReadAll(f) + b, err := io.ReadAll(f) if err != nil { return nil, err } @@ -175,24 +194,29 @@ func (fp *FileProvider) SessionRead(sid string) (Store, error) { // SessionExist Check file session exist. // it checks the file named from sid exist or not. -func (fp *FileProvider) SessionExist(sid string) bool { +func (fp *FileProvider) SessionExist(ctx context.Context, sid string) (bool, error) { filepder.lock.Lock() defer filepder.lock.Unlock() - _, err := os.Stat(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid)) - return err == nil + if len(sid) < 2 { + SLogger.Println("min length of session id is 2 but got length: ", sid) + return false, errors.New("min length of session id is 2") + } + + _, err := os.Stat(filepath.Join(fp.savePath, string(sid[0]), string(sid[1]), sid)) + return err == nil, nil } // SessionDestroy Remove all files in this save path -func (fp *FileProvider) SessionDestroy(sid string) error { +func (fp *FileProvider) SessionDestroy(ctx context.Context, sid string) error { filepder.lock.Lock() defer filepder.lock.Unlock() - os.Remove(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid)) + os.Remove(filepath.Join(fp.savePath, string(sid[0]), string(sid[1]), sid)) return nil } // SessionGC Recycle files in save path -func (fp *FileProvider) SessionGC() { +func (fp *FileProvider) SessionGC(context.Context) { filepder.lock.Lock() defer filepder.lock.Unlock() @@ -202,11 +226,9 @@ func (fp *FileProvider) SessionGC() { // SessionAll Get active file session number. // it walks save path to count files. -func (fp *FileProvider) SessionAll() int { +func (fp *FileProvider) SessionAll(context.Context) int { a := &activeSession{} - err := filepath.Walk(fp.savePath, func(path string, f os.FileInfo, err error) error { - return a.visit(path, f, err) - }) + err := filepath.Walk(fp.savePath, a.visit) if err != nil { SLogger.Printf("filepath.Walk() returned %v\n", err) return 0 @@ -215,15 +237,15 @@ func (fp *FileProvider) SessionAll() int { } // SessionRegenerate Generate new sid for file session. -// it delete old file and create new file named from new sid. -func (fp *FileProvider) SessionRegenerate(oldsid, sid string) (Store, error) { +// it deletes old file and create new file named from new sid. +func (fp *FileProvider) SessionRegenerate(ctx context.Context, oldsid, sid string) (Store, error) { filepder.lock.Lock() defer filepder.lock.Unlock() - oldPath := path.Join(fp.savePath, string(oldsid[0]), string(oldsid[1])) - oldSidFile := path.Join(oldPath, oldsid) - newPath := path.Join(fp.savePath, string(sid[0]), string(sid[1])) - newSidFile := path.Join(newPath, sid) + oldPath := filepath.Join(fp.savePath, string(oldsid[0]), string(oldsid[1])) + oldSidFile := filepath.Join(oldPath, oldsid) + newPath := filepath.Join(fp.savePath, string(sid[0]), string(sid[1])) + newSidFile := filepath.Join(newPath, sid) // new sid file is exist _, err := os.Stat(newSidFile) @@ -231,7 +253,7 @@ func (fp *FileProvider) SessionRegenerate(oldsid, sid string) (Store, error) { return nil, fmt.Errorf("newsid %s exist", newSidFile) } - err = os.MkdirAll(newPath, 0777) + err = os.MkdirAll(newPath, 0o755) if err != nil { SLogger.Println(err.Error()) } @@ -243,7 +265,7 @@ func (fp *FileProvider) SessionRegenerate(oldsid, sid string) (Store, error) { // 4.return FileSessionStore _, err = os.Stat(oldSidFile) if err == nil { - b, err := ioutil.ReadFile(oldSidFile) + b, err := os.ReadFile(oldSidFile) if err != nil { return nil, err } @@ -258,7 +280,7 @@ func (fp *FileProvider) SessionRegenerate(oldsid, sid string) (Store, error) { } } - ioutil.WriteFile(newSidFile, b, 0777) + os.WriteFile(newSidFile, b, 0o777) os.Remove(oldSidFile) os.Chtimes(newSidFile, time.Now(), time.Now()) ss := &FileSessionStore{sid: sid, values: kv} diff --git a/server/web/session/sess_file_test.go b/server/web/session/sess_file_test.go new file mode 100644 index 0000000000..f944e795d8 --- /dev/null +++ b/server/web/session/sess_file_test.go @@ -0,0 +1,474 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package session + +import ( + "context" + "fmt" + "net/http/httptest" + "os" + "sync" + "testing" + "time" +) + +const ( + sid = "Session_id" + sidNew = "Session_id_new" + sessionPath = "./_session_runtime" +) + +var mutex sync.Mutex + +func TestFileProviderSessionInit(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(context.Background(), 180, sessionPath) + if fp.maxlifetime != 180 { + t.Error() + } + + if fp.savePath != sessionPath { + t.Error() + } +} + +func TestFileProviderSessionExist(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(context.Background(), 180, sessionPath) + + exists, err := fp.SessionExist(context.Background(), sid) + if err != nil { + t.Error(err) + } + if exists { + t.Error() + } + + _, err = fp.SessionRead(context.Background(), sid) + if err != nil { + t.Error(err) + } + + exists, err = fp.SessionExist(context.Background(), sid) + if err != nil { + t.Error(err) + } + if !exists { + t.Error() + } +} + +func TestFileProviderSessionExist2(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(context.Background(), 180, sessionPath) + + exists, err := fp.SessionExist(context.Background(), sid) + if err != nil { + t.Error(err) + } + if exists { + t.Error() + } + + exists, err = fp.SessionExist(context.Background(), "") + if err == nil { + t.Error() + } + if exists { + t.Error() + } + + exists, err = fp.SessionExist(context.Background(), "1") + if err == nil { + t.Error() + } + if exists { + t.Error() + } +} + +func TestFileProviderSessionRead(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(context.Background(), 180, sessionPath) + + s, err := fp.SessionRead(context.Background(), sid) + if err != nil { + t.Error(err) + } + + _ = s.Set(nil, "sessionValue", 18975) + v := s.Get(nil, "sessionValue") + + if v.(int) != 18975 { + t.Error() + } +} + +func TestFileProviderSessionRead1(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(context.Background(), 180, sessionPath) + + _, err := fp.SessionRead(context.Background(), "") + if err == nil { + t.Error(err) + } + + _, err = fp.SessionRead(context.Background(), "1") + if err == nil { + t.Error(err) + } +} + +func TestFileProviderSessionAll(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(context.Background(), 180, sessionPath) + + sessionCount := 546 + + for i := 1; i <= sessionCount; i++ { + _, err := fp.SessionRead(context.Background(), fmt.Sprintf("%s_%d", sid, i)) + if err != nil { + t.Error(err) + } + } + + if fp.SessionAll(nil) != sessionCount { + t.Error() + } +} + +func TestFileProviderSessionRegenerate(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(context.Background(), 180, sessionPath) + + _, err := fp.SessionRead(context.Background(), sid) + if err != nil { + t.Error(err) + } + + exists, err := fp.SessionExist(context.Background(), sid) + if err != nil { + t.Error(err) + } + if !exists { + t.Error() + } + + _, err = fp.SessionRegenerate(context.Background(), sid, sidNew) + if err != nil { + t.Error(err) + } + + exists, err = fp.SessionExist(context.Background(), sid) + if err != nil { + t.Error(err) + } + if exists { + t.Error() + } + + exists, err = fp.SessionExist(context.Background(), sidNew) + if err != nil { + t.Error(err) + } + if !exists { + t.Error() + } +} + +func TestFileProviderSessionDestroy(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(context.Background(), 180, sessionPath) + + _, err := fp.SessionRead(context.Background(), sid) + if err != nil { + t.Error(err) + } + + exists, err := fp.SessionExist(context.Background(), sid) + if err != nil { + t.Error(err) + } + if !exists { + t.Error() + } + + err = fp.SessionDestroy(context.Background(), sid) + if err != nil { + t.Error(err) + } + + exists, err = fp.SessionExist(context.Background(), sid) + if err != nil { + t.Error(err) + } + if exists { + t.Error() + } +} + +func TestFileProviderSessionGC(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(context.Background(), 1, sessionPath) + + sessionCount := 412 + + for i := 1; i <= sessionCount; i++ { + _, err := fp.SessionRead(context.Background(), fmt.Sprintf("%s_%d", sid, i)) + if err != nil { + t.Error(err) + } + } + + time.Sleep(2 * time.Second) + + fp.SessionGC(nil) + if fp.SessionAll(nil) != 0 { + t.Error() + } +} + +func TestFileSessionStoreSet(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(context.Background(), 180, sessionPath) + + sessionCount := 100 + s, _ := fp.SessionRead(context.Background(), sid) + for i := 1; i <= sessionCount; i++ { + err := s.Set(nil, i, i) + if err != nil { + t.Error(err) + } + } +} + +func TestFileSessionStoreGet(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(context.Background(), 180, sessionPath) + + sessionCount := 100 + s, _ := fp.SessionRead(context.Background(), sid) + for i := 1; i <= sessionCount; i++ { + _ = s.Set(nil, i, i) + + v := s.Get(nil, i) + if v.(int) != i { + t.Error() + } + } +} + +func TestFileSessionStoreDelete(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(context.Background(), 180, sessionPath) + + s, _ := fp.SessionRead(context.Background(), sid) + s.Set(context.Background(), "1", 1) + + if s.Get(context.Background(), "1") == nil { + t.Error() + } + + s.Delete(context.Background(), "1") + + if s.Get(context.Background(), "1") != nil { + t.Error() + } +} + +func TestFileSessionStoreFlush(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(context.Background(), 180, sessionPath) + + sessionCount := 100 + s, _ := fp.SessionRead(context.Background(), sid) + for i := 1; i <= sessionCount; i++ { + _ = s.Set(nil, i, i) + } + + _ = s.Flush(nil) + + for i := 1; i <= sessionCount; i++ { + if s.Get(nil, i) != nil { + t.Error() + } + } +} + +func TestFileSessionStoreSessionID(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(context.Background(), 180, sessionPath) + + sessionCount := 85 + + for i := 1; i <= sessionCount; i++ { + s, err := fp.SessionRead(context.Background(), fmt.Sprintf("%s_%d", sid, i)) + if err != nil { + t.Error(err) + } + if s.SessionID(context.Background()) != fmt.Sprintf("%s_%d", sid, i) { + t.Error(err) + } + } +} + +func TestFileSessionStoreSessionRelease(t *testing.T) { + releaseSession(t, false) +} + +func TestFileSessionStoreSessionReleaseIfPresent(t *testing.T) { + releaseSession(t, true) +} + +func releaseSession(t *testing.T, requirePresent bool) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + + _ = fp.SessionInit(context.Background(), 180, sessionPath) + filepder.savePath = sessionPath + sessionCount := 85 + + for i := 1; i <= sessionCount; i++ { + s, err := fp.SessionRead(context.Background(), fmt.Sprintf("%s_%d", sid, i)) + if err != nil { + t.Error(err) + } + + s.Set(context.Background(), i, i) + if requirePresent { + s.SessionReleaseIfPresent(context.Background(), httptest.NewRecorder()) + } else { + s.SessionRelease(context.Background(), httptest.NewRecorder()) + } + + } + + for i := 1; i <= sessionCount; i++ { + s, err := fp.SessionRead(context.Background(), fmt.Sprintf("%s_%d", sid, i)) + if err != nil { + t.Error(err) + } + + if s.Get(nil, i).(int) != i { + t.Error() + } + } +} + +func TestFileSessionStoreSessionReleaseIfPresentAndSessionDestroy(t *testing.T) { + mutex.Lock() + defer mutex.Unlock() + os.RemoveAll(sessionPath) + defer os.RemoveAll(sessionPath) + fp := &FileProvider{} + s, err := fp.SessionRead(context.Background(), sid) + if err != nil { + return + } + + _ = fp.SessionInit(context.Background(), 180, sessionPath) + filepder.savePath = sessionPath + if err := fp.SessionDestroy(context.Background(), sid); err != nil { + t.Error(err) + return + } + wg := sync.WaitGroup{} + wg.Add(1) + go func() { + defer wg.Done() + s.SessionReleaseIfPresent(context.Background(), httptest.NewRecorder()) + }() + wg.Wait() + exist, err := fp.SessionExist(context.Background(), sid) + if err != nil { + t.Error(err) + } + if exist { + t.Fatalf("session %s should exist", sid) + } +} diff --git a/session/sess_mem.go b/server/web/session/sess_mem.go similarity index 73% rename from session/sess_mem.go rename to server/web/session/sess_mem.go index 64d8b05617..6f41f7cf7b 100644 --- a/session/sess_mem.go +++ b/server/web/session/sess_mem.go @@ -16,6 +16,7 @@ package session import ( "container/list" + "context" "net/http" "sync" "time" @@ -26,14 +27,14 @@ var mempder = &MemProvider{list: list.New(), sessions: make(map[string]*list.Ele // MemSessionStore memory session store. // it saved sessions in a map in memory. type MemSessionStore struct { - sid string //session id - timeAccessed time.Time //last access time - value map[interface{}]interface{} //session store + sid string // session id + timeAccessed time.Time // last access time + value map[interface{}]interface{} // session store lock sync.RWMutex } // Set value to memory session -func (st *MemSessionStore) Set(key, value interface{}) error { +func (st *MemSessionStore) Set(ctx context.Context, key, value interface{}) error { st.lock.Lock() defer st.lock.Unlock() st.value[key] = value @@ -41,7 +42,7 @@ func (st *MemSessionStore) Set(key, value interface{}) error { } // Get value from memory session by key -func (st *MemSessionStore) Get(key interface{}) interface{} { +func (st *MemSessionStore) Get(ctx context.Context, key interface{}) interface{} { st.lock.RLock() defer st.lock.RUnlock() if v, ok := st.value[key]; ok { @@ -51,7 +52,7 @@ func (st *MemSessionStore) Get(key interface{}) interface{} { } // Delete in memory session by key -func (st *MemSessionStore) Delete(key interface{}) error { +func (st *MemSessionStore) Delete(ctx context.Context, key interface{}) error { st.lock.Lock() defer st.lock.Unlock() delete(st.value, key) @@ -59,7 +60,7 @@ func (st *MemSessionStore) Delete(key interface{}) error { } // Flush clear all values in memory session -func (st *MemSessionStore) Flush() error { +func (st *MemSessionStore) Flush(context.Context) error { st.lock.Lock() defer st.lock.Unlock() st.value = make(map[interface{}]interface{}) @@ -67,12 +68,16 @@ func (st *MemSessionStore) Flush() error { } // SessionID get this id of memory session store -func (st *MemSessionStore) SessionID() string { +func (st *MemSessionStore) SessionID(context.Context) string { return st.sid } // SessionRelease Implement method, no used. -func (st *MemSessionStore) SessionRelease(w http.ResponseWriter) { +func (st *MemSessionStore) SessionRelease(_ context.Context, _ http.ResponseWriter) { +} + +// SessionReleaseIfPresent Implement method, no used. +func (*MemSessionStore) SessionReleaseIfPresent(_ context.Context, _ http.ResponseWriter) { } // MemProvider Implement the provider interface @@ -85,17 +90,17 @@ type MemProvider struct { } // SessionInit init memory session -func (pder *MemProvider) SessionInit(maxlifetime int64, savePath string) error { +func (pder *MemProvider) SessionInit(ctx context.Context, maxlifetime int64, savePath string) error { pder.maxlifetime = maxlifetime pder.savePath = savePath return nil } // SessionRead get memory session store by sid -func (pder *MemProvider) SessionRead(sid string) (Store, error) { +func (pder *MemProvider) SessionRead(ctx context.Context, sid string) (Store, error) { pder.lock.RLock() if element, ok := pder.sessions[sid]; ok { - go pder.SessionUpdate(sid) + go pder.SessionUpdate(nil, sid) pder.lock.RUnlock() return element.Value.(*MemSessionStore), nil } @@ -109,20 +114,20 @@ func (pder *MemProvider) SessionRead(sid string) (Store, error) { } // SessionExist check session store exist in memory session by sid -func (pder *MemProvider) SessionExist(sid string) bool { +func (pder *MemProvider) SessionExist(ctx context.Context, sid string) (bool, error) { pder.lock.RLock() defer pder.lock.RUnlock() if _, ok := pder.sessions[sid]; ok { - return true + return true, nil } - return false + return false, nil } // SessionRegenerate generate new sid for session store in memory session -func (pder *MemProvider) SessionRegenerate(oldsid, sid string) (Store, error) { +func (pder *MemProvider) SessionRegenerate(ctx context.Context, oldsid, sid string) (Store, error) { pder.lock.RLock() if element, ok := pder.sessions[oldsid]; ok { - go pder.SessionUpdate(oldsid) + go pder.SessionUpdate(nil, oldsid) pder.lock.RUnlock() pder.lock.Lock() element.Value.(*MemSessionStore).sid = sid @@ -141,7 +146,7 @@ func (pder *MemProvider) SessionRegenerate(oldsid, sid string) (Store, error) { } // SessionDestroy delete session store in memory session by id -func (pder *MemProvider) SessionDestroy(sid string) error { +func (pder *MemProvider) SessionDestroy(ctx context.Context, sid string) error { pder.lock.Lock() defer pder.lock.Unlock() if element, ok := pder.sessions[sid]; ok { @@ -153,7 +158,7 @@ func (pder *MemProvider) SessionDestroy(sid string) error { } // SessionGC clean expired session stores in memory session -func (pder *MemProvider) SessionGC() { +func (pder *MemProvider) SessionGC(context.Context) { pder.lock.RLock() for { element := pder.list.Back() @@ -175,12 +180,12 @@ func (pder *MemProvider) SessionGC() { } // SessionAll get count number of memory session -func (pder *MemProvider) SessionAll() int { +func (pder *MemProvider) SessionAll(context.Context) int { return pder.list.Len() } // SessionUpdate expand time of session store by id in memory session -func (pder *MemProvider) SessionUpdate(sid string) error { +func (pder *MemProvider) SessionUpdate(ctx context.Context, sid string) error { pder.lock.Lock() defer pder.lock.Unlock() if element, ok := pder.sessions[sid]; ok { diff --git a/session/sess_mem_test.go b/server/web/session/sess_mem_test.go similarity index 91% rename from session/sess_mem_test.go rename to server/web/session/sess_mem_test.go index 2e8934b825..e6d3547618 100644 --- a/session/sess_mem_test.go +++ b/server/web/session/sess_mem_test.go @@ -36,12 +36,12 @@ func TestMem(t *testing.T) { if err != nil { t.Fatal("set error,", err) } - defer sess.SessionRelease(w) - err = sess.Set("username", "astaxie") + defer sess.SessionRelease(nil, w) + err = sess.Set(nil, "username", "astaxie") if err != nil { t.Fatal("set error,", err) } - if username := sess.Get("username"); username != "astaxie" { + if username := sess.Get(nil, "username"); username != "astaxie" { t.Fatal("get username error") } if cookiestr := w.Header().Get("Set-Cookie"); cookiestr == "" { diff --git a/session/sess_test.go b/server/web/session/sess_test.go similarity index 100% rename from session/sess_test.go rename to server/web/session/sess_test.go diff --git a/session/sess_utils.go b/server/web/session/sess_utils.go similarity index 96% rename from session/sess_utils.go rename to server/web/session/sess_utils.go index 2e3376c714..b26493367f 100644 --- a/session/sess_utils.go +++ b/server/web/session/sess_utils.go @@ -19,7 +19,7 @@ import ( "crypto/cipher" "crypto/hmac" "crypto/rand" - "crypto/sha1" + "crypto/sha256" "crypto/subtle" "encoding/base64" "encoding/gob" @@ -29,7 +29,7 @@ import ( "strconv" "time" - "github.com/astaxie/beego/utils" + "github.com/beego/beego/v2/core/utils" ) func init() { @@ -98,7 +98,7 @@ func encrypt(block cipher.Block, value []byte) ([]byte, error) { // decrypt decrypts a value using the given block in counter mode. // -// The value to be decrypted must be prepended by a initialization vector +// The value to be decrypted must be prepended by an initialization vector // (http://goo.gl/zF67k) with the length of the block size. func decrypt(block cipher.Block, value []byte) ([]byte, error) { size := block.BlockSize() @@ -129,7 +129,7 @@ func encodeCookie(block cipher.Block, hashKey, name string, value map[interface{ b = encode(b) // 3. Create MAC for "name|date|value". Extra pipe to be used later. b = []byte(fmt.Sprintf("%s|%d|%s|", name, time.Now().UTC().Unix(), b)) - h := hmac.New(sha1.New, []byte(hashKey)) + h := hmac.New(sha256.New, []byte(hashKey)) h.Write(b) sig := h.Sum(nil) // Append mac, remove name. @@ -153,7 +153,7 @@ func decodeCookie(block cipher.Block, hashKey, name, value string, gcmaxlifetime } b = append([]byte(name+"|"), b[:len(b)-len(parts[2])]...) - h := hmac.New(sha1.New, []byte(hashKey)) + h := hmac.New(sha256.New, []byte(hashKey)) h.Write(b) sig := h.Sum(nil) if len(sig) != len(parts[2]) || subtle.ConstantTimeCompare(sig, parts[2]) != 1 { diff --git a/session/session.go b/server/web/session/session.go similarity index 72% rename from session/session.go rename to server/web/session/session.go index 46a9f1f0d7..9505604dc0 100644 --- a/session/session.go +++ b/server/web/session/session.go @@ -16,18 +16,19 @@ // // Usage: // import( -// "github.com/astaxie/beego/session" -// ) // -// func init() { -// globalSessions, _ = session.NewManager("memory", `{"cookieName":"gosessionid", "enableSetCookie,omitempty": true, "gclifetime":3600, "maxLifetime": 3600, "secure": false, "cookieLifeTime": 3600, "providerConfig": ""}`) -// go globalSessions.GC() -// } +// "github.com/beego/beego/v2/server/web/session" +// +// ) // -// more docs: http://beego.me/docs/module/session.md +// func init() { +// globalSessions, _ = session.NewManager("memory", `{"cookieName":"gosessionid", "enableSetCookie,omitempty": true, "gclifetime":3600, "maxLifetime": 3600, "secure": false, "cookieLifeTime": 3600, "providerConfig": ""}`) +// go globalSessions.GC() +// } package session import ( + "context" "crypto/rand" "encoding/hex" "errors" @@ -43,24 +44,25 @@ import ( // Store contains all data for one session process with specific id. type Store interface { - Set(key, value interface{}) error //set session value - Get(key interface{}) interface{} //get session value - Delete(key interface{}) error //delete session value - SessionID() string //back current sessionID - SessionRelease(w http.ResponseWriter) // release the resource & save data to provider & return the data - Flush() error //delete all data + Set(ctx context.Context, key, value interface{}) error // Set set session value + Get(ctx context.Context, key interface{}) interface{} // Get get session value + Delete(ctx context.Context, key interface{}) error // Delete delete session value + SessionID(ctx context.Context) string // SessionID return current sessionID + SessionReleaseIfPresent(ctx context.Context, w http.ResponseWriter) // SessionReleaseIfPresent release the resource & save data to provider & return the data when the session is present, not all implementation support this feature, you need to check if the specific implementation if support this feature. + SessionRelease(ctx context.Context, w http.ResponseWriter) // SessionRelease release the resource & save data to provider & return the data + Flush(ctx context.Context) error // Flush delete all data } // Provider contains global session methods and saved SessionStores. // it can operate a SessionStore by its id. type Provider interface { - SessionInit(gclifetime int64, config string) error - SessionRead(sid string) (Store, error) - SessionExist(sid string) bool - SessionRegenerate(oldsid, sid string) (Store, error) - SessionDestroy(sid string) error - SessionAll() int //get all active session - SessionGC() + SessionInit(ctx context.Context, gclifetime int64, config string) error + SessionRead(ctx context.Context, sid string) (Store, error) + SessionExist(ctx context.Context, sid string) (bool, error) + SessionRegenerate(ctx context.Context, oldsid, sid string) (Store, error) + SessionDestroy(ctx context.Context, sid string) error + SessionAll(ctx context.Context) int // get all active session + SessionGC(ctx context.Context) } var provides = make(map[string]Provider) @@ -69,19 +71,15 @@ var provides = make(map[string]Provider) var SLogger = NewSessionLog(os.Stderr) // Register makes a session provide available by the provided name. -// If Register is called twice with the same name or if driver is nil, -// it panics. +// If provider is nil, it panic func Register(name string, provide Provider) { if provide == nil { panic("session: Register provide is nil") } - if _, dup := provides[name]; dup { - panic("session: Register called twice for provider " + name) - } provides[name] = provide } -//GetProvider +// GetProvider func GetProvider(name string) (Provider, error) { provider, ok := provides[name] if !ok { @@ -90,24 +88,6 @@ func GetProvider(name string) (Provider, error) { return provider, nil } -// ManagerConfig define the session config -type ManagerConfig struct { - CookieName string `json:"cookieName"` - EnableSetCookie bool `json:"enableSetCookie,omitempty"` - Gclifetime int64 `json:"gclifetime"` - Maxlifetime int64 `json:"maxLifetime"` - DisableHTTPOnly bool `json:"disableHTTPOnly"` - Secure bool `json:"secure"` - CookieLifeTime int `json:"cookieLifeTime"` - ProviderConfig string `json:"providerConfig"` - Domain string `json:"domain"` - SessionIDLength int64 `json:"sessionIDLength"` - EnableSidInHTTPHeader bool `json:"EnableSidInHTTPHeader"` - SessionNameInHTTPHeader string `json:"SessionNameInHTTPHeader"` - EnableSidInURLQuery bool `json:"EnableSidInURLQuery"` - SessionIDPrefix string `json:"sessionIDPrefix"` -} - // Manager contains Provider and its configuration. type Manager struct { provider Provider @@ -148,7 +128,7 @@ func NewManager(provideName string, cf *ManagerConfig) (*Manager, error) { } } - err := provider.SessionInit(cf.Maxlifetime, cf.ProviderConfig) + err := provider.SessionInit(context.Background(), cf.Maxlifetime, cf.ProviderConfig) if err != nil { return nil, err } @@ -174,7 +154,7 @@ func (manager *Manager) GetProvider() Provider { // // error is not nil when there is anything wrong. // sid is empty when need to generate a new session id -// otherwise return an valid session id. +// otherwise return a valid session id. func (manager *Manager) getSid(r *http.Request) (string, error) { cookie, errs := r.Cookie(manager.config.CookieName) if errs != nil || cookie.Value == "" { @@ -211,8 +191,14 @@ func (manager *Manager) SessionStart(w http.ResponseWriter, r *http.Request) (se return nil, errs } - if sid != "" && manager.provider.SessionExist(sid) { - return manager.provider.SessionRead(sid) + if sid != "" { + exists, err := manager.provider.SessionExist(context.Background(), sid) + if err != nil { + return nil, err + } + if exists { + return manager.provider.SessionRead(context.Background(), sid) + } } // Generate a new session @@ -221,7 +207,7 @@ func (manager *Manager) SessionStart(w http.ResponseWriter, r *http.Request) (se return nil, errs } - session, err = manager.provider.SessionRead(sid) + session, err = manager.provider.SessionRead(context.Background(), sid) if err != nil { return nil, err } @@ -232,6 +218,7 @@ func (manager *Manager) SessionStart(w http.ResponseWriter, r *http.Request) (se HttpOnly: !manager.config.DisableHTTPOnly, Secure: manager.isSecure(r), Domain: manager.config.Domain, + SameSite: manager.config.CookieSameSite, } if manager.config.CookieLifeTime > 0 { cookie.MaxAge = manager.config.CookieLifeTime @@ -263,14 +250,18 @@ func (manager *Manager) SessionDestroy(w http.ResponseWriter, r *http.Request) { } sid, _ := url.QueryUnescape(cookie.Value) - manager.provider.SessionDestroy(sid) + manager.provider.SessionDestroy(context.Background(), sid) if manager.config.EnableSetCookie { expiration := time.Now() - cookie = &http.Cookie{Name: manager.config.CookieName, + cookie = &http.Cookie{ + Name: manager.config.CookieName, Path: "/", HttpOnly: !manager.config.DisableHTTPOnly, Expires: expiration, - MaxAge: -1} + MaxAge: -1, + Domain: manager.config.Domain, + SameSite: manager.config.CookieSameSite, + } http.SetCookie(w, cookie) } @@ -278,61 +269,72 @@ func (manager *Manager) SessionDestroy(w http.ResponseWriter, r *http.Request) { // GetSessionStore Get SessionStore by its id. func (manager *Manager) GetSessionStore(sid string) (sessions Store, err error) { - sessions, err = manager.provider.SessionRead(sid) + sessions, err = manager.provider.SessionRead(context.Background(), sid) return } // GC Start session gc process. // it can do gc in times after gc lifetime. func (manager *Manager) GC() { - manager.provider.SessionGC() + manager.provider.SessionGC(nil) time.AfterFunc(time.Duration(manager.config.Gclifetime)*time.Second, func() { manager.GC() }) } -// SessionRegenerateID Regenerate a session id for this SessionStore who's id is saving in http request. -func (manager *Manager) SessionRegenerateID(w http.ResponseWriter, r *http.Request) (session Store) { +// SessionRegenerateID Regenerate a session id for this SessionStore whose id is saving in http request. +func (manager *Manager) SessionRegenerateID(w http.ResponseWriter, r *http.Request) (Store, error) { sid, err := manager.sessionID() if err != nil { - return + return nil, err } + + var session Store cookie, err := r.Cookie(manager.config.CookieName) if err != nil || cookie.Value == "" { - //delete old cookie - session, _ = manager.provider.SessionRead(sid) - cookie = &http.Cookie{Name: manager.config.CookieName, - Value: url.QueryEscape(sid), - Path: "/", - HttpOnly: !manager.config.DisableHTTPOnly, - Secure: manager.isSecure(r), - Domain: manager.config.Domain, + // delete old cookie + session, err = manager.provider.SessionRead(context.Background(), sid) + if err != nil { + return nil, err + } + cookie = &http.Cookie{ + Name: manager.config.CookieName, + Value: url.QueryEscape(sid), } } else { - oldsid, _ := url.QueryUnescape(cookie.Value) - session, _ = manager.provider.SessionRegenerate(oldsid, sid) + oldsid, err := url.QueryUnescape(cookie.Value) + if err != nil { + return nil, err + } + session, err = manager.provider.SessionRegenerate(context.Background(), oldsid, sid) + if err != nil { + return nil, err + } cookie.Value = url.QueryEscape(sid) - cookie.HttpOnly = true - cookie.Path = "/" } if manager.config.CookieLifeTime > 0 { cookie.MaxAge = manager.config.CookieLifeTime cookie.Expires = time.Now().Add(time.Duration(manager.config.CookieLifeTime) * time.Second) } + + cookie.HttpOnly = !manager.config.DisableHTTPOnly + cookie.Path = "/" + cookie.Secure = manager.isSecure(r) + cookie.Domain = manager.config.Domain + cookie.SameSite = manager.config.CookieSameSite + if manager.config.EnableSetCookie { http.SetCookie(w, cookie) } r.AddCookie(cookie) - if manager.config.EnableSidInHTTPHeader { r.Header.Set(manager.config.SessionNameInHTTPHeader, sid) w.Header().Set(manager.config.SessionNameInHTTPHeader, sid) } - - return + return session, nil } // GetActiveSession Get all active sessions count number. func (manager *Manager) GetActiveSession() int { - return manager.provider.SessionAll() + return manager.provider.SessionAll(nil) } // SetSecure Set cookie with https. diff --git a/server/web/session/session_config.go b/server/web/session/session_config.go new file mode 100644 index 0000000000..65d36dde6b --- /dev/null +++ b/server/web/session/session_config.go @@ -0,0 +1,143 @@ +package session + +import "net/http" + +// ManagerConfig define the session config +type ManagerConfig struct { + EnableSetCookie bool `json:"enableSetCookie,omitempty"` + DisableHTTPOnly bool `json:"disableHTTPOnly"` + Secure bool `json:"secure"` + EnableSidInHTTPHeader bool `json:"EnableSidInHTTPHeader"` + EnableSidInURLQuery bool `json:"EnableSidInURLQuery"` + CookieName string `json:"cookieName"` + Gclifetime int64 `json:"gclifetime"` + Maxlifetime int64 `json:"maxLifetime"` + CookieLifeTime int `json:"cookieLifeTime"` + ProviderConfig string `json:"providerConfig"` + Domain string `json:"domain"` + SessionIDLength int64 `json:"sessionIDLength"` + SessionNameInHTTPHeader string `json:"SessionNameInHTTPHeader"` + SessionIDPrefix string `json:"sessionIDPrefix"` + CookieSameSite http.SameSite `json:"cookieSameSite"` +} + +func (c *ManagerConfig) Opts(opts ...ManagerConfigOpt) { + for _, opt := range opts { + opt(c) + } +} + +type ManagerConfigOpt func(config *ManagerConfig) + +func NewManagerConfig(opts ...ManagerConfigOpt) *ManagerConfig { + config := &ManagerConfig{} + for _, opt := range opts { + opt(config) + } + return config +} + +// CfgCookieName set key of session id +func CfgCookieName(cookieName string) ManagerConfigOpt { + return func(config *ManagerConfig) { + config.CookieName = cookieName + } +} + +// CfgSessionIdLength set len of session id +func CfgSessionIdLength(length int64) ManagerConfigOpt { + return func(config *ManagerConfig) { + config.SessionIDLength = length + } +} + +// CfgSessionIdPrefix set prefix of session id +func CfgSessionIdPrefix(prefix string) ManagerConfigOpt { + return func(config *ManagerConfig) { + config.SessionIDPrefix = prefix + } +} + +// CfgSetCookie whether set `Set-Cookie` header in HTTP response +func CfgSetCookie(enable bool) ManagerConfigOpt { + return func(config *ManagerConfig) { + config.EnableSetCookie = enable + } +} + +// CfgGcLifeTime set session gc lift time +func CfgGcLifeTime(lifeTime int64) ManagerConfigOpt { + return func(config *ManagerConfig) { + config.Gclifetime = lifeTime + } +} + +// CfgMaxLifeTime set session lift time +func CfgMaxLifeTime(lifeTime int64) ManagerConfigOpt { + return func(config *ManagerConfig) { + config.Maxlifetime = lifeTime + } +} + +// CfgCookieLifeTime set cookie lift time +func CfgCookieLifeTime(lifeTime int) ManagerConfigOpt { + return func(config *ManagerConfig) { + config.CookieLifeTime = lifeTime + } +} + +// CfgProviderConfig configure session provider +func CfgProviderConfig(providerConfig string) ManagerConfigOpt { + return func(config *ManagerConfig) { + config.ProviderConfig = providerConfig + } +} + +// CfgDomain set cookie domain +func CfgDomain(domain string) ManagerConfigOpt { + return func(config *ManagerConfig) { + config.Domain = domain + } +} + +// CfgSessionIdInHTTPHeader enable session id in http header +func CfgSessionIdInHTTPHeader(enable bool) ManagerConfigOpt { + return func(config *ManagerConfig) { + config.EnableSidInHTTPHeader = enable + } +} + +// CfgSetSessionNameInHTTPHeader set key of session id in http header +func CfgSetSessionNameInHTTPHeader(name string) ManagerConfigOpt { + return func(config *ManagerConfig) { + config.SessionNameInHTTPHeader = name + } +} + +// EnableSidInURLQuery enable session id in query string +func CfgEnableSidInURLQuery(enable bool) ManagerConfigOpt { + return func(config *ManagerConfig) { + config.EnableSidInURLQuery = enable + } +} + +// DisableHTTPOnly set HTTPOnly for http.Cookie +func CfgHTTPOnly(HTTPOnly bool) ManagerConfigOpt { + return func(config *ManagerConfig) { + config.DisableHTTPOnly = !HTTPOnly + } +} + +// CfgSecure set Secure for http.Cookie +func CfgSecure(Enable bool) ManagerConfigOpt { + return func(config *ManagerConfig) { + config.Secure = Enable + } +} + +// CfgSameSite set http.SameSite +func CfgSameSite(sameSite http.SameSite) ManagerConfigOpt { + return func(config *ManagerConfig) { + config.CookieSameSite = sameSite + } +} diff --git a/server/web/session/session_config_test.go b/server/web/session/session_config_test.go new file mode 100644 index 0000000000..0ea7d22b48 --- /dev/null +++ b/server/web/session/session_config_test.go @@ -0,0 +1,222 @@ +package session + +import ( + "net/http" + "testing" +) + +func TestCfgCookieLifeTime(t *testing.T) { + value := 8754 + c := NewManagerConfig( + CfgCookieLifeTime(value), + ) + + if c.CookieLifeTime != value { + t.Error() + } +} + +func TestCfgDomain(t *testing.T) { + value := `http://domain.com` + c := NewManagerConfig( + CfgDomain(value), + ) + + if c.Domain != value { + t.Error() + } +} + +func TestCfgSameSite(t *testing.T) { + value := http.SameSiteLaxMode + c := NewManagerConfig( + CfgSameSite(value), + ) + + if c.CookieSameSite != value { + t.Error() + } +} + +func TestCfgSecure(t *testing.T) { + c := NewManagerConfig( + CfgSecure(true), + ) + + if c.Secure != true { + t.Error() + } +} + +func TestCfgSecure1(t *testing.T) { + c := NewManagerConfig( + CfgSecure(false), + ) + + if c.Secure != false { + t.Error() + } +} + +func TestCfgSessionIdPrefix(t *testing.T) { + value := `sodiausodkljalsd` + c := NewManagerConfig( + CfgSessionIdPrefix(value), + ) + + if c.SessionIDPrefix != value { + t.Error() + } +} + +func TestCfgSetSessionNameInHTTPHeader(t *testing.T) { + value := `sodiausodkljalsd` + c := NewManagerConfig( + CfgSetSessionNameInHTTPHeader(value), + ) + + if c.SessionNameInHTTPHeader != value { + t.Error() + } +} + +func TestCfgCookieName(t *testing.T) { + value := `sodiausodkljalsd` + c := NewManagerConfig( + CfgCookieName(value), + ) + + if c.CookieName != value { + t.Error() + } +} + +func TestCfgEnableSidInURLQuery(t *testing.T) { + c := NewManagerConfig( + CfgEnableSidInURLQuery(true), + ) + + if c.EnableSidInURLQuery != true { + t.Error() + } +} + +func TestCfgGcLifeTime(t *testing.T) { + value := int64(5454) + c := NewManagerConfig( + CfgGcLifeTime(value), + ) + + if c.Gclifetime != value { + t.Error() + } +} + +func TestCfgHTTPOnly(t *testing.T) { + c := NewManagerConfig( + CfgHTTPOnly(true), + ) + + if c.DisableHTTPOnly != false { + t.Error() + } +} + +func TestCfgHTTPOnly2(t *testing.T) { + c := NewManagerConfig( + CfgHTTPOnly(false), + ) + + if c.DisableHTTPOnly != true { + t.Error() + } +} + +func TestCfgMaxLifeTime(t *testing.T) { + value := int64(5454) + c := NewManagerConfig( + CfgMaxLifeTime(value), + ) + + if c.Maxlifetime != value { + t.Error() + } +} + +func TestCfgProviderConfig(t *testing.T) { + value := `asodiuasldkj12i39809as` + c := NewManagerConfig( + CfgProviderConfig(value), + ) + + if c.ProviderConfig != value { + t.Error() + } +} + +func TestCfgSessionIdInHTTPHeader(t *testing.T) { + c := NewManagerConfig( + CfgSessionIdInHTTPHeader(true), + ) + + if c.EnableSidInHTTPHeader != true { + t.Error() + } +} + +func TestCfgSessionIdInHTTPHeader1(t *testing.T) { + c := NewManagerConfig( + CfgSessionIdInHTTPHeader(false), + ) + + if c.EnableSidInHTTPHeader != false { + t.Error() + } +} + +func TestCfgSessionIdLength(t *testing.T) { + value := int64(100) + c := NewManagerConfig( + CfgSessionIdLength(value), + ) + + if c.SessionIDLength != value { + t.Error() + } +} + +func TestCfgSetCookie(t *testing.T) { + c := NewManagerConfig( + CfgSetCookie(true), + ) + + if c.EnableSetCookie != true { + t.Error() + } +} + +func TestCfgSetCookie1(t *testing.T) { + c := NewManagerConfig( + CfgSetCookie(false), + ) + + if c.EnableSetCookie != false { + t.Error() + } +} + +func TestNewManagerConfig(t *testing.T) { + c := NewManagerConfig() + if c == nil { + t.Error() + } +} + +func TestManagerConfig_Opts(t *testing.T) { + c := NewManagerConfig() + c.Opts(CfgSetCookie(true)) + + if c.EnableSetCookie != true { + t.Error() + } +} diff --git a/server/web/session/session_provider_type.go b/server/web/session/session_provider_type.go new file mode 100644 index 0000000000..78dc116dce --- /dev/null +++ b/server/web/session/session_provider_type.go @@ -0,0 +1,18 @@ +package session + +type ProviderType string + +const ( + ProviderCookie ProviderType = `cookie` + ProviderFile ProviderType = `file` + ProviderMemory ProviderType = `memory` + ProviderCouchbase ProviderType = `couchbase` + ProviderLedis ProviderType = `ledis` + ProviderMemcache ProviderType = `memcache` + ProviderMysql ProviderType = `mysql` + ProviderPostgresql ProviderType = `postgresql` + ProviderRedis ProviderType = `redis` + ProviderRedisCluster ProviderType = `redis_cluster` + ProviderRedisSentinel ProviderType = `redis_sentinel` + ProviderSsdb ProviderType = `ssdb` +) diff --git a/session/ssdb/sess_ssdb.go b/server/web/session/ssdb/sess_ssdb.go similarity index 61% rename from session/ssdb/sess_ssdb.go rename to server/web/session/ssdb/sess_ssdb.go index de0c6360c5..7ce43affab 100644 --- a/session/ssdb/sess_ssdb.go +++ b/server/web/session/ssdb/sess_ssdb.go @@ -1,14 +1,17 @@ package ssdb import ( + "context" + "encoding/json" "errors" "net/http" "strconv" "strings" "sync" - "github.com/astaxie/beego/session" "github.com/ssdb/gossdb/ssdb" + + "github.com/beego/beego/v2/server/web/session" ) var ssdbProvider = &Provider{} @@ -16,35 +19,50 @@ var ssdbProvider = &Provider{} // Provider holds ssdb client and configs type Provider struct { client *ssdb.Client - host string - port int + Host string `json:"host"` + Port int `json:"port"` maxLifetime int64 } func (p *Provider) connectInit() error { var err error - if p.host == "" || p.port == 0 { + if p.Host == "" || p.Port == 0 { return errors.New("SessionInit First") } - p.client, err = ssdb.Connect(p.host, p.port) + p.client, err = ssdb.Connect(p.Host, p.Port) return err } // SessionInit init the ssdb with the config -func (p *Provider) SessionInit(maxLifetime int64, savePath string) error { +func (p *Provider) SessionInit(ctx context.Context, maxLifetime int64, cfg string) error { p.maxLifetime = maxLifetime - address := strings.Split(savePath, ":") - p.host = address[0] + cfg = strings.TrimSpace(cfg) var err error - if p.port, err = strconv.Atoi(address[1]); err != nil { + // we think this is v2.0, using json to init the session + if strings.HasPrefix(cfg, "{") { + err = json.Unmarshal([]byte(cfg), p) + } else { + err = p.initOldStyle(cfg) + } + if err != nil { return err } return p.connectInit() } +// for v1.x +func (p *Provider) initOldStyle(savePath string) error { + address := strings.Split(savePath, ":") + p.Host = address[0] + + var err error + p.Port, err = strconv.Atoi(address[1]) + return err +} + // SessionRead return a ssdb client session Store -func (p *Provider) SessionRead(sid string) (session.Store, error) { +func (p *Provider) SessionRead(ctx context.Context, sid string) (session.Store, error) { if p.client == nil { if err := p.connectInit(); err != nil { return nil, err @@ -67,11 +85,11 @@ func (p *Provider) SessionRead(sid string) (session.Store, error) { return rs, nil } -// SessionExist judged whether sid is exist in session -func (p *Provider) SessionExist(sid string) bool { +// SessionExist judged whether sid existed in session +func (p *Provider) SessionExist(ctx context.Context, sid string) (bool, error) { if p.client == nil { if err := p.connectInit(); err != nil { - panic(err) + return false, err } } value, err := p.client.Get(sid) @@ -79,14 +97,14 @@ func (p *Provider) SessionExist(sid string) bool { panic(err) } if value == nil || len(value.(string)) == 0 { - return false + return false, nil } - return true + return true, nil } // SessionRegenerate regenerate session with new sid and delete oldsid -func (p *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { - //conn.Do("setx", key, v, ttl) +func (p *Provider) SessionRegenerate(ctx context.Context, oldsid, sid string) (session.Store, error) { + // conn.Do("setx", key, v, ttl) if p.client == nil { if err := p.connectInit(); err != nil { return nil, err @@ -118,7 +136,7 @@ func (p *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) } // SessionDestroy destroy the sid -func (p *Provider) SessionDestroy(sid string) error { +func (p *Provider) SessionDestroy(ctx context.Context, sid string) error { if p.client == nil { if err := p.connectInit(); err != nil { return err @@ -129,11 +147,11 @@ func (p *Provider) SessionDestroy(sid string) error { } // SessionGC not implemented -func (p *Provider) SessionGC() { +func (p *Provider) SessionGC(context.Context) { } // SessionAll not implemented -func (p *Provider) SessionAll() int { +func (p *Provider) SessionAll(context.Context) int { return 0 } @@ -147,7 +165,7 @@ type SessionStore struct { } // Set the key and value -func (s *SessionStore) Set(key, value interface{}) error { +func (s *SessionStore) Set(ctx context.Context, key, value interface{}) error { s.lock.Lock() defer s.lock.Unlock() s.values[key] = value @@ -155,7 +173,7 @@ func (s *SessionStore) Set(key, value interface{}) error { } // Get return the value by the key -func (s *SessionStore) Get(key interface{}) interface{} { +func (s *SessionStore) Get(ctx context.Context, key interface{}) interface{} { s.lock.Lock() defer s.lock.Unlock() if value, ok := s.values[key]; ok { @@ -165,7 +183,7 @@ func (s *SessionStore) Get(key interface{}) interface{} { } // Delete the key in session store -func (s *SessionStore) Delete(key interface{}) error { +func (s *SessionStore) Delete(ctx context.Context, key interface{}) error { s.lock.Lock() defer s.lock.Unlock() delete(s.values, key) @@ -173,7 +191,7 @@ func (s *SessionStore) Delete(key interface{}) error { } // Flush delete all keys and values -func (s *SessionStore) Flush() error { +func (s *SessionStore) Flush(context.Context) error { s.lock.Lock() defer s.lock.Unlock() s.values = make(map[interface{}]interface{}) @@ -181,19 +199,28 @@ func (s *SessionStore) Flush() error { } // SessionID return the sessionID -func (s *SessionStore) SessionID() string { +func (s *SessionStore) SessionID(context.Context) string { return s.sid } // SessionRelease Store the keyvalues into ssdb -func (s *SessionStore) SessionRelease(w http.ResponseWriter) { - b, err := session.EncodeGob(s.values) +func (s *SessionStore) SessionRelease(_ context.Context, _ http.ResponseWriter) { + s.lock.RLock() + values := s.values + s.lock.RUnlock() + b, err := session.EncodeGob(values) if err != nil { return } s.client.Do("setx", s.sid, string(b), s.maxLifetime) } +// SessionReleaseIfPresent is not supported now +// Because ssdb does not support lua script or SETXX command +func (s *SessionStore) SessionReleaseIfPresent(c context.Context, w http.ResponseWriter) { + s.SessionRelease(c, w) +} + func init() { session.Register("ssdb", ssdbProvider) } diff --git a/server/web/session/ssdb/sess_ssdb_test.go b/server/web/session/ssdb/sess_ssdb_test.go new file mode 100644 index 0000000000..3de5da0a17 --- /dev/null +++ b/server/web/session/ssdb/sess_ssdb_test.go @@ -0,0 +1,41 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ssdb + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestProvider_SessionInit(t *testing.T) { + // using old style + savePath := `localhost:8080` + cp := &Provider{} + cp.SessionInit(context.Background(), 12, savePath) + assert.Equal(t, "localhost", cp.Host) + assert.Equal(t, 8080, cp.Port) + assert.Equal(t, int64(12), cp.maxLifetime) + + savePath = ` +{ "host": "localhost", "port": 8080} +` + cp = &Provider{} + cp.SessionInit(context.Background(), 12, savePath) + assert.Equal(t, "localhost", cp.Host) + assert.Equal(t, 8080, cp.Port) + assert.Equal(t, int64(12), cp.maxLifetime) +} diff --git a/staticfile.go b/server/web/staticfile.go similarity index 75% rename from staticfile.go rename to server/web/staticfile.go index 68241a8654..e99cd2970a 100644 --- a/staticfile.go +++ b/server/web/staticfile.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package beego +package web import ( "bytes" @@ -26,8 +26,10 @@ import ( "sync" "time" - "github.com/astaxie/beego/context" - "github.com/astaxie/beego/logs" + lru "github.com/hashicorp/golang-lru" + + "github.com/beego/beego/v2/core/logs" + "github.com/beego/beego/v2/server/web/context" ) var errNotStaticRequest = errors.New("request not a static file request") @@ -37,12 +39,12 @@ func serverStaticRouter(ctx *context.Context) { return } - forbidden, filePath, fileInfo, err := lookupFile(ctx) + fbd, filePath, fileInfo, err := lookupFile(ctx) if err == errNotStaticRequest { return } - if forbidden { + if fbd { exception("403", ctx) return } @@ -63,13 +65,17 @@ func serverStaticRouter(ctx *context.Context) { } ctx.Redirect(302, redirectURL) } else { - //serveFile will list dir + // serveFile will list dir http.ServeFile(ctx.ResponseWriter, ctx.Request, filePath) } return + } else if fileInfo.Size() > int64(BConfig.WebConfig.StaticCacheFileSize) { + // over size file serve with http module + http.ServeFile(ctx.ResponseWriter, ctx.Request, filePath) + return } - var enableCompress = BConfig.EnableGzip && isStaticCompress(filePath) + enableCompress := BConfig.EnableGzip && isStaticCompress(filePath) var acceptEncoding string if enableCompress { acceptEncoding = context.ParseEncoding(ctx.Request) @@ -93,10 +99,11 @@ func serverStaticRouter(ctx *context.Context) { } type serveContentHolder struct { - data []byte - modTime time.Time - size int64 - encoding string + data []byte + modTime time.Time + size int64 + originSize int64 // original file size:to judge file changed + encoding string } type serveContentReader struct { @@ -104,22 +111,36 @@ type serveContentReader struct { } var ( - staticFileMap = make(map[string]*serveContentHolder) - mapLock sync.RWMutex + staticFileLruCache *lru.Cache + lruLock sync.RWMutex ) func openFile(filePath string, fi os.FileInfo, acceptEncoding string) (bool, string, *serveContentHolder, *serveContentReader, error) { + if staticFileLruCache == nil { + // avoid lru cache error + if BConfig.WebConfig.StaticCacheFileNum >= 1 { + staticFileLruCache, _ = lru.New(BConfig.WebConfig.StaticCacheFileNum) + } else { + staticFileLruCache, _ = lru.New(1) + } + } mapKey := acceptEncoding + ":" + filePath - mapLock.RLock() - mapFile := staticFileMap[mapKey] - mapLock.RUnlock() + lruLock.RLock() + var mapFile *serveContentHolder + if cacheItem, ok := staticFileLruCache.Get(mapKey); ok { + mapFile = cacheItem.(*serveContentHolder) + } + lruLock.RUnlock() if isOk(mapFile, fi) { reader := &serveContentReader{Reader: bytes.NewReader(mapFile.data)} return mapFile.encoding != "", mapFile.encoding, mapFile, reader, nil } - mapLock.Lock() - defer mapLock.Unlock() - if mapFile = staticFileMap[mapKey]; !isOk(mapFile, fi) { + lruLock.Lock() + defer lruLock.Unlock() + if cacheItem, ok := staticFileLruCache.Get(mapKey); ok { + mapFile = cacheItem.(*serveContentHolder) + } + if !isOk(mapFile, fi) { file, err := os.Open(filePath) if err != nil { return false, "", nil, nil, err @@ -130,8 +151,10 @@ func openFile(filePath string, fi os.FileInfo, acceptEncoding string) (bool, str if err != nil { return false, "", nil, nil, err } - mapFile = &serveContentHolder{data: bufferWriter.Bytes(), modTime: fi.ModTime(), size: int64(bufferWriter.Len()), encoding: n} - staticFileMap[mapKey] = mapFile + mapFile = &serveContentHolder{data: bufferWriter.Bytes(), modTime: fi.ModTime(), size: int64(bufferWriter.Len()), originSize: fi.Size(), encoding: n} + if isOk(mapFile, fi) { + staticFileLruCache.Add(mapKey, mapFile) + } } reader := &serveContentReader{Reader: bytes.NewReader(mapFile.data)} @@ -141,8 +164,10 @@ func openFile(filePath string, fi os.FileInfo, acceptEncoding string) (bool, str func isOk(s *serveContentHolder, fi os.FileInfo) bool { if s == nil { return false + } else if s.size > int64(BConfig.WebConfig.StaticCacheFileSize) { + return false } - return s.modTime == fi.ModTime() && s.size == fi.Size() + return s.modTime == fi.ModTime() && s.originSize == fi.Size() } // isStaticCompress detect static files @@ -178,7 +203,7 @@ func searchFile(ctx *context.Context) (string, os.FileInfo, error) { if !strings.Contains(requestPath, prefix) { continue } - if prefix != "/" && len(requestPath) > len(prefix) && requestPath[len(prefix)] != '/' { + if prefix != "/" && len(requestPath) > len(prefix) && requestPath[len(prefix)] != '/' { continue } filePath := path.Join(staticDir, requestPath[len(prefix):]) diff --git a/staticfile_test.go b/server/web/staticfile_test.go similarity index 62% rename from staticfile_test.go rename to server/web/staticfile_test.go index 69667bf850..ad2df41a92 100644 --- a/staticfile_test.go +++ b/server/web/staticfile_test.go @@ -1,18 +1,20 @@ -package beego +package web import ( "bytes" "compress/gzip" "compress/zlib" + "fmt" "io" - "io/ioutil" "os" "path/filepath" "testing" ) -var currentWorkDir, _ = os.Getwd() -var licenseFile = filepath.Join(currentWorkDir, "LICENSE") +var ( + currentWorkDir, _ = os.Getwd() + licenseFile = filepath.Join(currentWorkDir, "LICENSE") +) func testOpenFile(encoding string, content []byte, t *testing.T) { fi, _ := os.Stat(licenseFile) @@ -26,9 +28,10 @@ func testOpenFile(encoding string, content []byte, t *testing.T) { assetOpenFileAndContent(sch, reader, content, t) } + func TestOpenStaticFile_1(t *testing.T) { file, _ := os.Open(licenseFile) - content, _ := ioutil.ReadAll(file) + content, _ := io.ReadAll(file) testOpenFile("", content, t) } @@ -38,35 +41,61 @@ func TestOpenStaticFileGzip_1(t *testing.T) { fileWriter, _ := gzip.NewWriterLevel(&zipBuf, gzip.BestCompression) io.Copy(fileWriter, file) fileWriter.Close() - content, _ := ioutil.ReadAll(&zipBuf) + content, _ := io.ReadAll(&zipBuf) testOpenFile("gzip", content, t) } + func TestOpenStaticFileDeflate_1(t *testing.T) { file, _ := os.Open(licenseFile) var zipBuf bytes.Buffer fileWriter, _ := zlib.NewWriterLevel(&zipBuf, zlib.BestCompression) io.Copy(fileWriter, file) fileWriter.Close() - content, _ := ioutil.ReadAll(&zipBuf) + content, _ := io.ReadAll(&zipBuf) testOpenFile("deflate", content, t) } +func TestStaticCacheWork(t *testing.T) { + encodings := []string{"", "gzip", "deflate"} + + fi, _ := os.Stat(licenseFile) + for _, encoding := range encodings { + _, _, first, _, err := openFile(licenseFile, fi, encoding) + if err != nil { + t.Error(err) + continue + } + + _, _, second, _, err := openFile(licenseFile, fi, encoding) + if err != nil { + t.Error(err) + continue + } + + address1 := fmt.Sprintf("%p", first) + address2 := fmt.Sprintf("%p", second) + if address1 != address2 { + t.Errorf("encoding '%v' can not hit cache", encoding) + } + } +} + func assetOpenFileAndContent(sch *serveContentHolder, reader *serveContentReader, content []byte, t *testing.T) { t.Log(sch.size, len(content)) if sch.size != int64(len(content)) { t.Log("static content file size not same") t.Fail() } - bs, _ := ioutil.ReadAll(reader) + bs, _ := io.ReadAll(reader) for i, v := range content { if v != bs[i] { t.Log("content not same") t.Fail() } } - if len(staticFileMap) == 0 { + if staticFileLruCache.Len() == 0 { t.Log("men map is empty") t.Fail() } diff --git a/toolbox/statistics.go b/server/web/statistics.go similarity index 79% rename from toolbox/statistics.go rename to server/web/statistics.go index d014544c33..d0f3ebd9d7 100644 --- a/toolbox/statistics.go +++ b/server/web/statistics.go @@ -12,12 +12,15 @@ // See the License for the specific language governing permissions and // limitations under the License. -package toolbox +package web import ( "fmt" + "html/template" "sync" "time" + + "github.com/beego/beego/v2/core/utils" ) // Statistics struct @@ -33,7 +36,7 @@ type Statistics struct { // URLMap contains several statistics struct to log different data type URLMap struct { lock sync.RWMutex - LengthLimit int //limit the urlmap's length if it's equal to 0 there's no limit + LengthLimit int // limit the urlmap's length if it's equal to 0 there's no limit urlmap map[string]map[string]*Statistics } @@ -63,7 +66,6 @@ func (m *URLMap) AddStatistics(requestMethod, requestURL, requestController stri } m.urlmap[requestURL][requestMethod] = nb } - } else { if m.LengthLimit > 0 && m.LengthLimit <= len(m.urlmap) { return @@ -87,7 +89,7 @@ func (m *URLMap) GetMap() map[string]interface{} { m.lock.RLock() defer m.lock.RUnlock() - var fields = []string{"requestUrl", "method", "times", "used", "max used", "min used", "avg used"} + fields := []string{"requestUrl", "method", "times", "used", "max used", "min used", "avg used"} var resultLists [][]string content := make(map[string]interface{}) @@ -96,17 +98,17 @@ func (m *URLMap) GetMap() map[string]interface{} { for k, v := range m.urlmap { for kk, vv := range v { result := []string{ - fmt.Sprintf("% -50s", k), + fmt.Sprintf("% -50s", template.HTMLEscapeString(k)), fmt.Sprintf("% -10s", kk), fmt.Sprintf("% -16d", vv.RequestNum), fmt.Sprintf("%d", vv.TotalTime), - fmt.Sprintf("% -16s", toS(vv.TotalTime)), + fmt.Sprintf("% -16s", utils.ToShortTimeFormat(vv.TotalTime)), fmt.Sprintf("%d", vv.MaxTime), - fmt.Sprintf("% -16s", toS(vv.MaxTime)), + fmt.Sprintf("% -16s", utils.ToShortTimeFormat(vv.MaxTime)), fmt.Sprintf("%d", vv.MinTime), - fmt.Sprintf("% -16s", toS(vv.MinTime)), + fmt.Sprintf("% -16s", utils.ToShortTimeFormat(vv.MinTime)), fmt.Sprintf("%d", time.Duration(int64(vv.TotalTime)/vv.RequestNum)), - fmt.Sprintf("% -16s", toS(time.Duration(int64(vv.TotalTime)/vv.RequestNum))), + fmt.Sprintf("% -16s", utils.ToShortTimeFormat(time.Duration(int64(vv.TotalTime)/vv.RequestNum))), } resultLists = append(resultLists, result) } @@ -117,8 +119,8 @@ func (m *URLMap) GetMap() map[string]interface{} { // GetMapData return all mapdata func (m *URLMap) GetMapData() []map[string]interface{} { - m.lock.Lock() - defer m.lock.Unlock() + m.lock.RLock() + defer m.lock.RUnlock() var resultLists []map[string]interface{} @@ -128,10 +130,10 @@ func (m *URLMap) GetMapData() []map[string]interface{} { "request_url": k, "method": kk, "times": vv.RequestNum, - "total_time": toS(vv.TotalTime), - "max_time": toS(vv.MaxTime), - "min_time": toS(vv.MinTime), - "avg_time": toS(time.Duration(int64(vv.TotalTime) / vv.RequestNum)), + "total_time": utils.ToShortTimeFormat(vv.TotalTime), + "max_time": utils.ToShortTimeFormat(vv.MaxTime), + "min_time": utils.ToShortTimeFormat(vv.MinTime), + "avg_time": utils.ToShortTimeFormat(time.Duration(int64(vv.TotalTime) / vv.RequestNum)), } resultLists = append(resultLists, result) } diff --git a/toolbox/statistics_test.go b/server/web/statistics_test.go similarity index 98% rename from toolbox/statistics_test.go rename to server/web/statistics_test.go index ac29476c0a..7c83e15a06 100644 --- a/toolbox/statistics_test.go +++ b/server/web/statistics_test.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package toolbox +package web import ( "encoding/json" diff --git a/swagger/swagger.go b/server/web/swagger/swagger.go similarity index 99% rename from swagger/swagger.go rename to server/web/swagger/swagger.go index a55676cdcf..c20b31ed6a 100644 --- a/swagger/swagger.go +++ b/server/web/swagger/swagger.go @@ -106,7 +106,7 @@ type Parameter struct { type ParameterItems struct { Type string `json:"type,omitempty" yaml:"type,omitempty"` Format string `json:"format,omitempty" yaml:"format,omitempty"` - Items []*ParameterItems `json:"items,omitempty" yaml:"items,omitempty"` //Required if type is "array". Describes the type of items in the array. + Items []*ParameterItems `json:"items,omitempty" yaml:"items,omitempty"` // Required if type is "array". Describes the type of items in the array. CollectionFormat string `json:"collectionFormat,omitempty" yaml:"collectionFormat,omitempty"` Default string `json:"default,omitempty" yaml:"default,omitempty"` } diff --git a/template.go b/server/web/template.go similarity index 93% rename from template.go rename to server/web/template.go index 59875be7b3..304689cbe4 100644 --- a/template.go +++ b/server/web/template.go @@ -12,14 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. -package beego +package web import ( "errors" "fmt" "html/template" "io" - "io/ioutil" "net/http" "os" "path/filepath" @@ -27,8 +26,8 @@ import ( "strings" "sync" - "github.com/astaxie/beego/logs" - "github.com/astaxie/beego/utils" + "github.com/beego/beego/v2/core/logs" + "github.com/beego/beego/v2/core/utils" ) var ( @@ -103,7 +102,7 @@ func init() { beegoTplFuncMap["lt"] = lt // < beegoTplFuncMap["ne"] = ne // != - beegoTplFuncMap["urlfor"] = URLFor // build a URL to match a Controller and it's method + beegoTplFuncMap["urlfor"] = URLFor // build an URL to match a Controller and it's method } // AddFuncMap let user to register a func in the template. @@ -163,12 +162,12 @@ func AddTemplateExt(ext string) { } // AddViewPath adds a new path to the supported view paths. -//Can later be used by setting a controller ViewPath to this folder -//will panic if called after beego.Run() +// Can later be used by setting a controller ViewPath to this folder +// will panic if called after beego.Run() func AddViewPath(viewPath string) error { if beeViewPathTemplateLocked { if _, exist := beeViewPathTemplates[viewPath]; exist { - return nil //Ignore if viewpath already exists + return nil // Ignore if viewpath already exists } panic("Can not add new view paths after beego.Run()") } @@ -202,9 +201,7 @@ func BuildTemplate(dir string, files ...string) error { root: dir, files: make(map[string][]string), } - err = Walk(fs, dir, func(path string, f os.FileInfo, err error) error { - return self.visit(path, f, err) - }) + err = Walk(fs, dir, self.visit) if err != nil { fmt.Printf("Walk() returned %v\n", err) return err @@ -252,7 +249,7 @@ func getTplDeep(root string, fs http.FileSystem, file string, parent string, t * panic("can't find template file:" + file) } defer f.Close() - data, err := ioutil.ReadAll(f) + data, err := io.ReadAll(f) if err != nil { return nil, [][]string{}, err } @@ -303,7 +300,7 @@ func _getTemplate(t0 *template.Template, root string, fs http.FileSystem, subMod if tpl != nil { continue } - //first check filename + // first check filename for _, otherFile := range others { if otherFile == m[1] { var subMods1 [][]string @@ -316,7 +313,7 @@ func _getTemplate(t0 *template.Template, root string, fs http.FileSystem, subMod break } } - //second check define + // second check define for _, otherFile := range others { var data []byte fileAbsPath := filepath.Join(root, otherFile) @@ -326,7 +323,7 @@ func _getTemplate(t0 *template.Template, root string, fs http.FileSystem, subMod logs.Trace("template file parse error, not success open file:", err) continue } - data, err = ioutil.ReadAll(f) + data, err = io.ReadAll(f) f.Close() if err != nil { logs.Trace("template file parse error, not success read file:", err) @@ -351,7 +348,6 @@ func _getTemplate(t0 *template.Template, root string, fs http.FileSystem, subMod } } } - } return } @@ -368,14 +364,14 @@ func SetTemplateFSFunc(fnt templateFSFunc) { } // SetViewsPath sets view directory path in beego application. -func SetViewsPath(path string) *App { +func SetViewsPath(path string) *HttpServer { BConfig.WebConfig.ViewsPath = path return BeeApp } // SetStaticPath sets static directory path and proper url pattern in beego application. // if beego.SetStaticPath("static","public"), visit /static/* to load static file in folder "public". -func SetStaticPath(url string, path string) *App { +func SetStaticPath(url string, path string) *HttpServer { if !strings.HasPrefix(url, "/") { url = "/" + url } @@ -387,7 +383,7 @@ func SetStaticPath(url string, path string) *App { } // DelStaticPath removes the static folder setting in this url pattern in beego application. -func DelStaticPath(url string) *App { +func DelStaticPath(url string) *HttpServer { if !strings.HasPrefix(url, "/") { url = "/" + url } @@ -399,7 +395,7 @@ func DelStaticPath(url string) *App { } // AddTemplateEngine add a new templatePreProcessor which support extension -func AddTemplateEngine(extension string, fn templatePreProcessor) *App { +func AddTemplateEngine(extension string, fn templatePreProcessor) *HttpServer { AddTemplateExt(extension) beeTemplateEngines[extension] = fn return BeeApp diff --git a/template_test.go b/server/web/template_test.go similarity index 85% rename from template_test.go rename to server/web/template_test.go index 287faadcf3..2c2585ce50 100644 --- a/template_test.go +++ b/server/web/template_test.go @@ -12,16 +12,19 @@ // See the License for the specific language governing permissions and // limitations under the License. -package beego +package web import ( "bytes" - "github.com/astaxie/beego/testdata" - "github.com/elazarl/go-bindata-assetfs" "net/http" "os" "path/filepath" "testing" + + assetfs "github.com/elazarl/go-bindata-assetfs" + "github.com/stretchr/testify/assert" + + "github.com/beego/beego/v2/test" ) var header = `{{define "header"}} @@ -46,17 +49,19 @@ var block = `{{define "block"}} {{end}}` func TestTemplate(t *testing.T) { - dir := "_beeTmp" + tmpDir := os.TempDir() + dir := filepath.Join(tmpDir, "_beeTmp", "TestTemplate") files := []string{ "header.tpl", "index.tpl", "blocks/block.tpl", } - if err := os.MkdirAll(dir, 0777); err != nil { + if err := os.MkdirAll(dir, 0o777); err != nil { t.Fatal(err) } for k, name := range files { - os.MkdirAll(filepath.Dir(filepath.Join(dir, name)), 0777) + dirErr := os.MkdirAll(filepath.Dir(filepath.Join(dir, name)), 0o777) + assert.Nil(t, dirErr) if f, err := os.Create(filepath.Join(dir, name)); err != nil { t.Fatal(err) } else { @@ -95,6 +100,7 @@ var menu = ` ` + var user = ` @@ -107,9 +113,10 @@ var user = ` ` func TestRelativeTemplate(t *testing.T) { - dir := "_beeTmp" + tmpDir := os.TempDir() + dir := filepath.Join(tmpDir, "_beeTmp") - //Just add dir to known viewPaths + // Just add dir to known viewPaths if err := AddViewPath(dir); err != nil { t.Fatal(err) } @@ -118,11 +125,11 @@ func TestRelativeTemplate(t *testing.T) { "easyui/public/menu.tpl", "easyui/rbac/user.tpl", } - if err := os.MkdirAll(dir, 0777); err != nil { + if err := os.MkdirAll(dir, 0o777); err != nil { t.Fatal(err) } for k, name := range files { - os.MkdirAll(filepath.Dir(filepath.Join(dir, name)), 0777) + os.MkdirAll(filepath.Dir(filepath.Join(dir, name)), 0o777) if f, err := os.Create(filepath.Join(dir, name)); err != nil { t.Fatal(err) } else { @@ -218,25 +225,33 @@ var output = ` ` func TestTemplateLayout(t *testing.T) { - dir := "_beeTmp" + tmpDir, err := os.Getwd() + assert.Nil(t, err) + + dir := filepath.Join(tmpDir, "_beeTmp", "TestTemplateLayout") files := []string{ "add.tpl", "layout_blog.tpl", } - if err := os.MkdirAll(dir, 0777); err != nil { + if err := os.MkdirAll(dir, 0o777); err != nil { t.Fatal(err) } + for k, name := range files { - os.MkdirAll(filepath.Dir(filepath.Join(dir, name)), 0777) + dirErr := os.MkdirAll(filepath.Dir(filepath.Join(dir, name)), 0o777) + assert.Nil(t, dirErr) if f, err := os.Create(filepath.Join(dir, name)); err != nil { t.Fatal(err) } else { if k == 0 { - f.WriteString(add) + _, writeErr := f.WriteString(add) + assert.Nil(t, writeErr) } else if k == 1 { - f.WriteString(layoutBlog) + _, writeErr := f.WriteString(layoutBlog) + assert.Nil(t, writeErr) } - f.Close() + clErr := f.Close() + assert.Nil(t, clErr) } } if err := AddViewPath(dir); err != nil { @@ -247,6 +262,7 @@ func TestTemplateLayout(t *testing.T) { t.Fatalf("should be 2 but got %v", len(beeTemplates)) } out := bytes.NewBufferString("") + if err := beeTemplates["add.tpl"].ExecuteTemplate(out, "add.tpl", map[string]string{"Title": "Hello", "SomeVar": "val"}); err != nil { t.Fatal(err) } @@ -291,7 +307,7 @@ var outputBinData = ` func TestFsBinData(t *testing.T) { SetTemplateFSFunc(func() http.FileSystem { - return TestingFileSystem{&assetfs.AssetFS{Asset: testdata.Asset, AssetDir: testdata.AssetDir, AssetInfo: testdata.AssetInfo}} + return TestingFileSystem{&assetfs.AssetFS{Asset: test.Asset, AssetDir: test.AssetDir, AssetInfo: test.AssetInfo}} }) dir := "views" if err := AddViewPath("views"); err != nil { diff --git a/templatefunc.go b/server/web/templatefunc.go similarity index 74% rename from templatefunc.go rename to server/web/templatefunc.go index d62442aec5..f7cce06a3f 100644 --- a/templatefunc.go +++ b/server/web/templatefunc.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package beego +package web import ( "errors" @@ -25,13 +25,9 @@ import ( "strconv" "strings" "time" -) -const ( - formatTime = "15:04:05" - formatDate = "2006-01-02" - formatDateTime = "2006-01-02 15:04:05" - formatDateTimeT = "2006-01-02T15:04:05" + "github.com/beego/beego/v2/core/logs" + "github.com/beego/beego/v2/server/web/context" ) // Substr returns the substr from start to length. @@ -54,15 +50,14 @@ func Substr(s string, start, length int) string { // HTML2str returns escaping text convert from html. func HTML2str(html string) string { - re := regexp.MustCompile(`\<[\S\s]+?\>`) html = re.ReplaceAllStringFunc(html, strings.ToLower) - //remove STYLE + // remove STYLE re = regexp.MustCompile(`\`) html = re.ReplaceAllString(html, "") - //remove SCRIPT + // remove SCRIPT re = regexp.MustCompile(`\`) html = re.ReplaceAllString(html, "") @@ -85,7 +80,7 @@ func DateFormat(t time.Time, layout string) (datestring string) { var datePatterns = []string{ // year "Y", "2006", // A full numeric representation of a year, 4 digits Examples: 1999 or 2003 - "y", "06", //A two digit representation of a year Examples: 99 or 03 + "y", "06", // A two digit representation of a year Examples: 99 or 03 // month "m", "01", // Numeric representation of a month, with leading zeros 01 through 12 @@ -160,7 +155,7 @@ func NotNil(a interface{}) (isNil bool) { func GetConfig(returnType, key string, defaultVal interface{}) (value interface{}, err error) { switch returnType { case "String": - value = AppConfig.String(key) + value, err = AppConfig.String(key) case "Bool": value, err = AppConfig.Bool(key) case "Int": @@ -201,7 +196,7 @@ func Str2html(raw string) template.HTML { // Htmlquote returns quoted html string. func Htmlquote(text string) string { - //HTML编码为实体符号 + // HTML编码为实体符号 /* Encodes `text` for raw use in HTML. >>> htmlquote("<'&\\">") @@ -220,7 +215,7 @@ func Htmlquote(text string) string { // Htmlunquote returns unquoted html string. func Htmlunquote(text string) string { - //实体符号解释为HTML + // 实体符号解释为HTML /* Decodes `text` that's HTML quoted. >>> htmlunquote('<'&">') @@ -233,29 +228,27 @@ func Htmlunquote(text string) string { } // URLFor returns url string with another registered controller handler with params. -// usage: // -// URLFor(".index") -// print URLFor("index") -// router /login -// print URLFor("login") -// print URLFor("login", "next","/"") -// router /profile/:username -// print UrlFor("profile", ":username","John Doe") -// result: -// / -// /login -// /login?next=/ -// /user/John%20Doe +// usage: // -// more detail http://beego.me/docs/mvc/controller/urlbuilding.md +// URLFor(".index") +// print URLFor("index") +// router /login +// print URLFor("login") +// print URLFor("login", "next","/"") +// router /profile/:username +// print UrlFor("profile", ":username","John Doe") +// result: +// / +// /login +// /login?next=/ +// /user/John%20Doe func URLFor(endpoint string, values ...interface{}) string { return BeeApp.Handlers.URLFor(endpoint, values...) } // AssetsJs returns script tag with src string. func AssetsJs(text string) template.HTML { - text = "" return template.HTML(text) @@ -263,167 +256,16 @@ func AssetsJs(text string) template.HTML { // AssetsCSS returns stylesheet link tag with src string. func AssetsCSS(text string) template.HTML { - text = "" return template.HTML(text) } -// ParseForm will parse form values to struct via tag. -// Support for anonymous struct. -func parseFormToStruct(form url.Values, objT reflect.Type, objV reflect.Value) error { - for i := 0; i < objT.NumField(); i++ { - fieldV := objV.Field(i) - if !fieldV.CanSet() { - continue - } - - fieldT := objT.Field(i) - if fieldT.Anonymous && fieldT.Type.Kind() == reflect.Struct { - err := parseFormToStruct(form, fieldT.Type, fieldV) - if err != nil { - return err - } - continue - } - - tags := strings.Split(fieldT.Tag.Get("form"), ",") - var tag string - if len(tags) == 0 || len(tags[0]) == 0 { - tag = fieldT.Name - } else if tags[0] == "-" { - continue - } else { - tag = tags[0] - } - - formValues := form[tag] - var value string - if len(formValues) == 0 { - defaultValue := fieldT.Tag.Get("default") - if defaultValue != "" { - value = defaultValue - } else { - continue - } - } - if len(formValues) == 1 { - value = formValues[0] - if value == "" { - continue - } - } - - switch fieldT.Type.Kind() { - case reflect.Bool: - if strings.ToLower(value) == "on" || strings.ToLower(value) == "1" || strings.ToLower(value) == "yes" { - fieldV.SetBool(true) - continue - } - if strings.ToLower(value) == "off" || strings.ToLower(value) == "0" || strings.ToLower(value) == "no" { - fieldV.SetBool(false) - continue - } - b, err := strconv.ParseBool(value) - if err != nil { - return err - } - fieldV.SetBool(b) - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - x, err := strconv.ParseInt(value, 10, 64) - if err != nil { - return err - } - fieldV.SetInt(x) - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - x, err := strconv.ParseUint(value, 10, 64) - if err != nil { - return err - } - fieldV.SetUint(x) - case reflect.Float32, reflect.Float64: - x, err := strconv.ParseFloat(value, 64) - if err != nil { - return err - } - fieldV.SetFloat(x) - case reflect.Interface: - fieldV.Set(reflect.ValueOf(value)) - case reflect.String: - fieldV.SetString(value) - case reflect.Struct: - switch fieldT.Type.String() { - case "time.Time": - var ( - t time.Time - err error - ) - if len(value) >= 25 { - value = value[:25] - t, err = time.ParseInLocation(time.RFC3339, value, time.Local) - } else if len(value) >= 19 { - if strings.Contains(value, "T") { - value = value[:19] - t, err = time.ParseInLocation(formatDateTimeT, value, time.Local) - } else { - value = value[:19] - t, err = time.ParseInLocation(formatDateTime, value, time.Local) - } - } else if len(value) >= 10 { - if len(value) > 10 { - value = value[:10] - } - t, err = time.ParseInLocation(formatDate, value, time.Local) - } else if len(value) >= 8 { - if len(value) > 8 { - value = value[:8] - } - t, err = time.ParseInLocation(formatTime, value, time.Local) - } - if err != nil { - return err - } - fieldV.Set(reflect.ValueOf(t)) - } - case reflect.Slice: - if fieldT.Type == sliceOfInts { - formVals := form[tag] - fieldV.Set(reflect.MakeSlice(reflect.SliceOf(reflect.TypeOf(int(1))), len(formVals), len(formVals))) - for i := 0; i < len(formVals); i++ { - val, err := strconv.Atoi(formVals[i]) - if err != nil { - return err - } - fieldV.Index(i).SetInt(int64(val)) - } - } else if fieldT.Type == sliceOfStrings { - formVals := form[tag] - fieldV.Set(reflect.MakeSlice(reflect.SliceOf(reflect.TypeOf("")), len(formVals), len(formVals))) - for i := 0; i < len(formVals); i++ { - fieldV.Index(i).SetString(formVals[i]) - } - } - } - } - return nil -} - // ParseForm will parse form values to struct via tag. func ParseForm(form url.Values, obj interface{}) error { - objT := reflect.TypeOf(obj) - objV := reflect.ValueOf(obj) - if !isStructPtr(objT) { - return fmt.Errorf("%v must be a struct pointer", obj) - } - objT = objT.Elem() - objV = objV.Elem() - - return parseFormToStruct(form, objT, objV) + return context.ParseForm(form, obj) } -var sliceOfInts = reflect.TypeOf([]int(nil)) -var sliceOfStrings = reflect.TypeOf([]string(nil)) - var unKind = map[reflect.Kind]bool{ reflect.Uintptr: true, reflect.Complex64: true, @@ -440,10 +282,11 @@ var unKind = map[reflect.Kind]bool{ // RenderForm will render object to form html. // obj must be a struct pointer. +// nolint func RenderForm(obj interface{}) template.HTML { objT := reflect.TypeOf(obj) objV := reflect.ValueOf(obj) - if !isStructPtr(objT) { + if objT.Kind() != reflect.Ptr || objT.Elem().Kind() != reflect.Struct { return template.HTML("") } objT = objT.Elem() @@ -468,7 +311,8 @@ func RenderForm(obj interface{}) template.HTML { return template.HTML(strings.Join(raw, "
")) } -// renderFormField returns a string containing HTML of a single form field. +// renderFormField returns a string containing HTML of a single form field. In case of select fType, it will retrun +// select tag with options. Value for select fType must be comma separated string which are use are func renderFormField(label, name, fType string, value interface{}, id string, class string, required bool) string { if id != "" { id = " id=\"" + id + "\"" @@ -487,6 +331,25 @@ func renderFormField(label, name, fType string, value interface{}, id string, cl return fmt.Sprintf(`%v`, label, id, class, name, fType, value, requiredString) } + if fType == "select" { + valueStr, ok := value.(string) + if !ok { + logs.Error("for select value must comma separated string that are the options for select") + return "" + } + + var selectBuilder strings.Builder + selectBuilder.WriteString(fmt.Sprintf(`%v
`, label, id, class, name)) + + for _, option := range strings.Split(valueStr, ",") { + selectBuilder.WriteString(fmt.Sprintf(`
`, option, option)) + } + + selectBuilder.WriteString(``) + + return selectBuilder.String() + } + return fmt.Sprintf(`%v<%v%v%v name="%v"%v>%v`, label, fType, id, class, name, requiredString, value, fType) } @@ -502,7 +365,7 @@ func isValidForInput(fType string) bool { } // parseFormTag takes the stuct-tag of a StructField and parses the `form` value. -// returned are the form label, name-property, type and wether the field should be ignored. +// returned are the form label, name-property, type and whether the field should be ignored. func parseFormTag(fieldT reflect.StructField) (label, name, fType string, id string, class string, ignored bool, required bool) { tags := strings.Split(fieldT.Tag.Get("form"), ",") label = fieldT.Name + ": " @@ -548,10 +411,6 @@ func parseFormTag(fieldT reflect.StructField) (label, name, fType string, id str return } -func isStructPtr(t reflect.Type) bool { - return t.Kind() == reflect.Ptr && t.Elem().Kind() == reflect.Struct -} - // go1.2 added template funcs. begin var ( errBadComparisonType = errors.New("invalid type for comparison") @@ -605,10 +464,23 @@ func eq(arg1 interface{}, arg2 ...interface{}) (bool, error) { if err != nil { return false, err } + truth := false if k1 != k2 { - return false, errBadComparison + // Special case: Can compare integer values regardless of type's sign. + switch { + case k1 == intKind && k2 == uintKind: + truth = v1.Int() >= 0 && uint64(v1.Int()) == v2.Uint() + case k1 == uintKind && k2 == intKind: + truth = v2.Int() >= 0 && v1.Uint() == uint64(v2.Int()) + default: + return false, errBadComparison + } + if truth { + return true, nil + } else { + return false, nil + } } - truth := false switch k1 { case boolKind: truth = v1.Bool() == v2.Bool() @@ -651,23 +523,32 @@ func lt(arg1, arg2 interface{}) (bool, error) { if err != nil { return false, err } - if k1 != k2 { - return false, errBadComparison - } truth := false - switch k1 { - case boolKind, complexKind: - return false, errBadComparisonType - case floatKind: - truth = v1.Float() < v2.Float() - case intKind: - truth = v1.Int() < v2.Int() - case stringKind: - truth = v1.String() < v2.String() - case uintKind: - truth = v1.Uint() < v2.Uint() - default: - panic("invalid kind") + if k1 != k2 { + // Special case: Can compare integer values regardless of type's sign. + switch { + case k1 == intKind && k2 == uintKind: + truth = v1.Int() < 0 || uint64(v1.Int()) < v2.Uint() + case k1 == uintKind && k2 == intKind: + truth = v2.Int() >= 0 && v1.Uint() < uint64(v2.Int()) + default: + return false, errBadComparison + } + } else { + switch k1 { + case boolKind, complexKind: + return false, errBadComparisonType + case floatKind: + truth = v1.Float() < v2.Float() + case intKind: + truth = v1.Int() < v2.Int() + case stringKind: + truth = v1.String() < v2.String() + case uintKind: + truth = v1.Uint() < v2.Uint() + default: + return false, errBadComparisonType + } } return truth, nil } @@ -704,12 +585,13 @@ func ge(arg1, arg2 interface{}) (bool, error) { // MapGet getting value from map by keys // usage: -// Data["m"] = M{ -// "a": 1, -// "1": map[string]float64{ -// "c": 4, -// }, -// } +// +// Data["m"] = M{ +// "a": 1, +// "1": map[string]float64{ +// "c": 4, +// }, +// } // // {{ map_get m "a" }} // return 1 // {{ map_get m 1 "c" }} // return 4 diff --git a/templatefunc_test.go b/server/web/templatefunc_test.go similarity index 63% rename from templatefunc_test.go rename to server/web/templatefunc_test.go index b4c19c2ef7..241163782c 100644 --- a/templatefunc_test.go +++ b/server/web/templatefunc_test.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package beego +package web import ( "html/template" @@ -205,13 +205,13 @@ func TestRenderForm(t *testing.T) { ID int `form:"-"` Name interface{} `form:"username"` Age int `form:"age,text,年龄:"` - Sex string + Sex string `form:"sex,select"` Email []string Intro string `form:",textarea"` Ignored string `form:"-"` } - u := user{Name: "test", Intro: "Some Text"} + u := user{Name: "test", Intro: "Some Text", Sex: "Male,Female"} output := RenderForm(u) if output != template.HTML("") { t.Errorf("output should be empty but got %v", output) @@ -220,7 +220,10 @@ func TestRenderForm(t *testing.T) { result := template.HTML( `Name:
` + `年龄:
` + - `Sex:
` + + `Sex:
` + `Intro: `) if output != result { t.Errorf("output should equal `%v` but got `%v`", result, output) @@ -299,7 +302,6 @@ func TestParseFormTag(t *testing.T) { if !(name == "name" && !required) { t.Errorf("Form Tag containing only name and not required was not correctly parsed.") } - } func TestMapGet(t *testing.T) { @@ -378,3 +380,206 @@ func TestMapGet(t *testing.T) { t.Errorf("Error happens %v", err) } } + +func Test_eq(t *testing.T) { + tests := []struct { + a interface{} + b interface{} + result bool + }{ + {uint8(1), int(1), true}, + {uint8(3), int(1), false}, + {uint16(1), int(1), true}, + {uint16(3), int(1), false}, + {uint32(1), int(1), true}, + {uint32(3), int(1), false}, + {uint(1), int(1), true}, + {uint(3), int(1), false}, + {uint64(1), int(1), true}, + {uint64(3), int(1), false}, + {int8(-1), uint(1), false}, + {int16(-2), uint(1), false}, + {int32(-3), uint(1), false}, + {int64(-4), uint(1), false}, + {int8(1), uint(1), true}, + {int16(1), uint(1), true}, + {int32(1), uint(1), true}, + {int64(1), uint(1), true}, + {int8(-1), uint8(1), false}, + {int16(-2), uint8(1), false}, + {int32(-3), uint8(1), false}, + {int64(-4), uint8(1), false}, + {int8(1), uint8(1), true}, + {int16(1), uint8(1), true}, + {int32(1), uint8(1), true}, + {int64(1), uint8(1), true}, + {int8(-1), uint16(1), false}, + {int16(-2), uint16(1), false}, + {int32(-3), uint16(1), false}, + {int64(-4), uint16(1), false}, + {int8(1), uint16(1), true}, + {int16(1), uint16(1), true}, + {int32(1), uint16(1), true}, + {int64(1), uint16(1), true}, + {int8(-1), uint32(1), false}, + {int16(-2), uint32(1), false}, + {int32(-3), uint32(1), false}, + {int64(-4), uint32(1), false}, + {int8(1), uint32(1), true}, + {int16(1), uint32(1), true}, + {int32(1), uint32(1), true}, + {int64(1), uint32(1), true}, + {int8(-1), uint64(1), false}, + {int16(-2), uint64(1), false}, + {int32(-3), uint64(1), false}, + {int64(-4), uint64(1), false}, + {int8(1), uint64(1), true}, + {int16(1), uint64(1), true}, + {int32(1), uint64(1), true}, + {int64(1), uint64(1), true}, + } + + for _, test := range tests { + if res, err := eq(test.a, test.b); err != nil { + if res != test.result { + t.Errorf("a:%v(%T) equals b:%v(%T) should be %v", test.a, test.a, test.b, test.b, test.result) + } + } + } +} + +func Test_lt(t *testing.T) { + tests := []struct { + a interface{} + b interface{} + result bool + }{ + {uint8(1), int(3), true}, + {uint8(1), int(1), false}, + {uint8(3), int(1), false}, + {uint16(1), int(3), true}, + {uint16(1), int(1), false}, + {uint16(3), int(1), false}, + {uint32(1), int(3), true}, + {uint32(1), int(1), false}, + {uint32(3), int(1), false}, + {uint(1), int(3), true}, + {uint(1), int(1), false}, + {uint(3), int(1), false}, + {uint64(1), int(3), true}, + {uint64(1), int(1), false}, + {uint64(3), int(1), false}, + {int(-1), int(1), true}, + {int(1), int(1), false}, + {int(1), int(3), true}, + {int(3), int(1), false}, + {int8(-1), uint(1), true}, + {int8(1), uint(1), false}, + {int8(1), uint(3), true}, + {int8(3), uint(1), false}, + {int16(-1), uint(1), true}, + {int16(1), uint(1), false}, + {int16(1), uint(3), true}, + {int16(3), uint(1), false}, + {int32(-1), uint(1), true}, + {int32(1), uint(1), false}, + {int32(1), uint(3), true}, + {int32(3), uint(1), false}, + {int64(-1), uint(1), true}, + {int64(1), uint(1), false}, + {int64(1), uint(3), true}, + {int64(3), uint(1), false}, + {int(-1), uint(1), true}, + {int(1), uint(1), false}, + {int(1), uint(3), true}, + {int(3), uint(1), false}, + {int8(-1), uint8(1), true}, + {int8(1), uint8(1), false}, + {int8(1), uint8(3), true}, + {int8(3), uint8(1), false}, + {int16(-1), uint8(1), true}, + {int16(1), uint8(1), false}, + {int16(1), uint8(3), true}, + {int16(3), uint8(1), false}, + {int32(-1), uint8(1), true}, + {int32(1), uint8(1), false}, + {int32(1), uint8(3), true}, + {int32(3), uint8(1), false}, + {int64(-1), uint8(1), true}, + {int64(1), uint8(1), false}, + {int64(1), uint8(3), true}, + {int64(3), uint8(1), false}, + {int(-1), uint8(1), true}, + {int(1), uint8(1), false}, + {int(1), uint8(3), true}, + {int(3), uint8(1), false}, + {int8(-1), uint16(1), true}, + {int8(1), uint16(1), false}, + {int8(1), uint16(3), true}, + {int8(3), uint16(1), false}, + {int16(-1), uint16(1), true}, + {int16(1), uint16(1), false}, + {int16(1), uint16(3), true}, + {int16(3), uint16(1), false}, + {int32(-1), uint16(1), true}, + {int32(1), uint16(1), false}, + {int32(1), uint16(3), true}, + {int32(3), uint16(1), false}, + {int64(-1), uint16(1), true}, + {int64(1), uint16(1), false}, + {int64(1), uint16(3), true}, + {int64(3), uint16(1), false}, + {int(-1), uint16(1), true}, + {int(1), uint16(1), false}, + {int(1), uint16(3), true}, + {int(3), uint16(1), false}, + {int8(-1), uint32(1), true}, + {int8(1), uint32(1), false}, + {int8(1), uint32(3), true}, + {int8(3), uint32(1), false}, + {int16(-1), uint32(1), true}, + {int16(1), uint32(1), false}, + {int16(1), uint32(3), true}, + {int16(3), uint32(1), false}, + {int32(-1), uint32(1), true}, + {int32(1), uint32(1), false}, + {int32(1), uint32(3), true}, + {int32(3), uint32(1), false}, + {int64(-1), uint32(1), true}, + {int64(1), uint32(1), false}, + {int64(1), uint32(3), true}, + {int64(3), uint32(1), false}, + {int(-1), uint32(1), true}, + {int(1), uint32(1), false}, + {int(1), uint32(3), true}, + {int(3), uint32(1), false}, + {int8(-1), uint64(1), true}, + {int8(1), uint64(1), false}, + {int8(1), uint64(3), true}, + {int8(3), uint64(1), false}, + {int16(-1), uint64(1), true}, + {int16(1), uint64(1), false}, + {int16(1), uint64(3), true}, + {int16(3), uint64(1), false}, + {int32(-1), uint64(1), true}, + {int32(1), uint64(1), false}, + {int32(1), uint64(3), true}, + {int32(3), uint64(1), false}, + {int64(-1), uint64(1), true}, + {int64(1), uint64(1), false}, + {int64(1), uint64(3), true}, + {int64(3), uint64(1), false}, + {int(-1), uint64(1), true}, + {int(1), uint64(1), false}, + {int(1), uint64(3), true}, + {int(3), uint64(1), false}, + } + + for _, test := range tests { + if res, err := lt(test.a, test.b); err != nil { + if res != test.result { + t.Errorf("a:%v(%T) lt b:%v(%T) should be %v", test.a, test.a, test.b, test.b, test.result) + } + } + } +} diff --git a/server/web/test/static/file.txt b/server/web/test/static/file.txt new file mode 100644 index 0000000000..7083a77206 --- /dev/null +++ b/server/web/test/static/file.txt @@ -0,0 +1,45 @@ + + + \ No newline at end of file diff --git a/tree.go b/server/web/tree.go similarity index 91% rename from tree.go rename to server/web/tree.go index 9e53003bc0..adcb885377 100644 --- a/tree.go +++ b/server/web/tree.go @@ -12,33 +12,31 @@ // See the License for the specific language governing permissions and // limitations under the License. -package beego +package web import ( "path" "regexp" "strings" - "github.com/astaxie/beego/context" - "github.com/astaxie/beego/utils" + "github.com/beego/beego/v2/core/utils" + "github.com/beego/beego/v2/server/web/context" ) -var ( - allowSuffixExt = []string{".json", ".xml", ".html"} -) +var allowSuffixExt = []string{".json", ".xml", ".html"} // Tree has three elements: FixRouter/wildcard/leaves // fixRouter stores Fixed Router // wildcard stores params // leaves store the endpoint information type Tree struct { - //prefix set for static router + // prefix set for static router prefix string - //search fix route first + // search fix route first fixrouters []*Tree - //if set, failure to match fixrouters search then search wildcard + // if set, failure to match fixrouters search then search wildcard wildcard *Tree - //if set, failure to match wildcard search + // if set, failure to match wildcard search leaves []*leafInfo } @@ -68,13 +66,13 @@ func (t *Tree) addtree(segments []string, tree *Tree, wildcards []string, reg st filterTreeWithPrefix(tree, wildcards, reg) } } - //Rule: /login/*/access match /login/2009/11/access - //if already has *, and when loop the access, should as a regexpStr + // Rule: /login/*/access match /login/2009/11/access + // if already has *, and when loop the access, should as a regexpStr if !iswild && utils.InSlice(":splat", wildcards) { iswild = true regexpStr = seg } - //Rule: /user/:id/* + // Rule: /user/:id/* if seg == "*" && len(wildcards) > 0 && reg == "" { regexpStr = "(.+)" } @@ -209,9 +207,9 @@ func (t *Tree) AddRouter(pattern string, runObject interface{}) { func (t *Tree) addseg(segments []string, route interface{}, wildcards []string, reg string) { if len(segments) == 0 { if reg != "" { - t.leaves = append(t.leaves, &leafInfo{runObject: route, wildcards: wildcards, regexps: regexp.MustCompile("^" + reg + "$")}) + t.leaves = append([]*leafInfo{{runObject: route, wildcards: wildcards, regexps: regexp.MustCompile("^" + reg + "$")}}, t.leaves...) } else { - t.leaves = append(t.leaves, &leafInfo{runObject: route, wildcards: wildcards}) + t.leaves = append([]*leafInfo{{runObject: route, wildcards: wildcards}}, t.leaves...) } } else { seg := segments[0] @@ -221,13 +219,13 @@ func (t *Tree) addseg(segments []string, route interface{}, wildcards []string, t.addseg(segments[1:], route, wildcards, reg) params = params[1:] } - //Rule: /login/*/access match /login/2009/11/access - //if already has *, and when loop the access, should as a regexpStr + // Rule: /login/*/access match /login/2009/11/access + // if already has *, and when loop the access, should as a regexpStr if !iswild && utils.InSlice(":splat", wildcards) { iswild = true regexpStr = seg } - //Rule: /user/:id/* + // Rule: /user/:id/* if seg == "*" && len(wildcards) > 0 && reg == "" { regexpStr = "(.+)" } @@ -284,7 +282,9 @@ func (t *Tree) addseg(segments []string, route interface{}, wildcards []string, // Match router to runObject & params func (t *Tree) Match(pattern string, ctx *context.Context) (runObject interface{}) { - if len(pattern) == 0 || pattern[0] != '/' { + // fix issue 4961, deal with "./ ../ //" + pattern = path.Clean(pattern) + if pattern == "" || pattern[0] != '/' { return nil } w := make([]string, 0, 20) @@ -293,13 +293,14 @@ func (t *Tree) Match(pattern string, ctx *context.Context) (runObject interface{ func (t *Tree) match(treePattern string, pattern string, wildcardValues []string, ctx *context.Context) (runObject interface{}) { if len(pattern) > 0 { - i := 0 - for ; i < len(pattern) && pattern[i] == '/'; i++ { + i, l := 0, len(pattern) + for i < l && pattern[i] == '/' { + i++ } pattern = pattern[i:] } // Handle leaf nodes: - if len(pattern) == 0 { + if pattern == "" { for _, l := range t.leaves { if ok := l.match(treePattern, wildcardValues, ctx); ok { return l.runObject @@ -316,7 +317,8 @@ func (t *Tree) match(treePattern string, pattern string, wildcardValues []string } var seg string i, l := 0, len(pattern) - for ; i < l && pattern[i] != '/'; i++ { + for i < l && pattern[i] != '/' { + i++ } if i == 0 { seg = pattern @@ -327,7 +329,7 @@ func (t *Tree) match(treePattern string, pattern string, wildcardValues []string } for _, subTree := range t.fixrouters { if subTree.prefix == seg { - if len(pattern) != 0 && pattern[0] == '/' { + if pattern != "" && pattern[0] == '/' { treePattern = pattern[1:] } else { treePattern = pattern @@ -341,7 +343,8 @@ func (t *Tree) match(treePattern string, pattern string, wildcardValues []string if runObject == nil && len(t.fixrouters) > 0 { // Filter the .json .xml .html extension for _, str := range allowSuffixExt { - if strings.HasSuffix(seg, str) { + // pattern == "" avoid cases: /aaa.html/aaa.html could access /aaa/:bbb + if strings.HasSuffix(seg, str) && pattern == "" { for _, subTree := range t.fixrouters { if subTree.prefix == seg[:len(seg)-len(str)] { runObject = subTree.match(treePattern, pattern, wildcardValues, ctx) @@ -392,7 +395,7 @@ type leafInfo struct { } func (leaf *leafInfo) match(treePattern string, wildcardValues []string, ctx *context.Context) (ok bool) { - //fmt.Println("Leaf:", wildcardValues, leaf.wildcards, leaf.regexps) + // fmt.Println("Leaf:", wildcardValues, leaf.wildcards, leaf.regexps) if leaf.regexps == nil { if len(wildcardValues) == 0 && len(leaf.wildcards) == 0 { // static path return true @@ -499,7 +502,7 @@ func splitSegment(key string) (bool, []string, string) { continue } if start { - //:id:int and :name:string + // :id:int and :name:string if v == ':' { if len(key) >= i+4 { if key[i+1:i+4] == "int" { diff --git a/server/web/tree_test.go b/server/web/tree_test.go new file mode 100644 index 0000000000..af0c888ef3 --- /dev/null +++ b/server/web/tree_test.go @@ -0,0 +1,376 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package web + +import ( + "strings" + "testing" + "time" + + "github.com/beego/beego/v2/server/web/context" +) + +type testInfo struct { + pattern string + requestUrl string + params map[string]string + shouldMatchOrNot bool +} + +var routers []testInfo + +func matchTestInfo(pattern, url string, params map[string]string) testInfo { + return testInfo{ + pattern: pattern, + requestUrl: url, + params: params, + shouldMatchOrNot: true, + } +} + +func notMatchTestInfo(pattern, url string) testInfo { + return testInfo{ + pattern: pattern, + requestUrl: url, + params: nil, + shouldMatchOrNot: false, + } +} + +func init() { + const ( + abcHTML = "/suffix/abc.html" + abcSuffix = "/abc/suffix/*" + ) + + routers = []testInfo{ + // match example + matchTestInfo("/topic/?:auth:int", "/topic", nil), + matchTestInfo("/topic/?:auth:int", "/topic/123", map[string]string{":auth": "123"}), + matchTestInfo("/topic/:id/?:auth", "/topic/1", map[string]string{":id": "1"}), + matchTestInfo("/topic/:id/?:auth", "/topic/1/2", map[string]string{":id": "1", ":auth": "2"}), + matchTestInfo("/topic/:id/?:auth:int", "/topic/1", map[string]string{":id": "1"}), + matchTestInfo("/topic/:id/?:auth:int", "/topic/1/123", map[string]string{":id": "1", ":auth": "123"}), + matchTestInfo("/:id", "/123", map[string]string{":id": "123"}), + matchTestInfo("/hello/?:id", "/hello", map[string]string{":id": ""}), + matchTestInfo("/", "/", nil), + matchTestInfo("/customer/login", "/customer/login", nil), + matchTestInfo("/customer/login", "/customer/login.json", map[string]string{":ext": "json"}), + // This case need to be modified when fix issue 4961, "//" will be replaced with "/" and last "/" will be deleted before route. + matchTestInfo("/*", "/http://customer/123/", map[string]string{":splat": "http:/customer/123"}), + matchTestInfo("/*", "/customer/2009/12/11", map[string]string{":splat": "customer/2009/12/11"}), + matchTestInfo("/aa/*/bb", "/aa/2009/bb", map[string]string{":splat": "2009"}), + matchTestInfo("/cc/*/dd", "/cc/2009/11/dd", map[string]string{":splat": "2009/11"}), + matchTestInfo("/cc/:id/*", "/cc/2009/11/dd", map[string]string{":id": "2009", ":splat": "11/dd"}), + matchTestInfo("/ee/:year/*/ff", "/ee/2009/11/ff", map[string]string{":year": "2009", ":splat": "11"}), + matchTestInfo("/thumbnail/:size/uploads/*", "/thumbnail/100x100/uploads/items/2014/04/20/dPRCdChkUd651t1Hvs18.jpg", map[string]string{":size": "100x100", ":splat": "items/2014/04/20/dPRCdChkUd651t1Hvs18.jpg"}), + matchTestInfo("/*.*", "/nice/api.json", map[string]string{":path": "nice/api", ":ext": "json"}), + matchTestInfo("/:name/*.*", "/nice/api.json", map[string]string{":name": "nice", ":path": "api", ":ext": "json"}), + matchTestInfo("/:name/test/*.*", "/nice/test/api.json", map[string]string{":name": "nice", ":path": "api", ":ext": "json"}), + matchTestInfo("/dl/:width:int/:height:int/*.*", "/dl/48/48/05ac66d9bda00a3acf948c43e306fc9a.jpg", map[string]string{":width": "48", ":height": "48", ":ext": "jpg", ":path": "05ac66d9bda00a3acf948c43e306fc9a"}), + matchTestInfo("/v1/shop/:id:int", "/v1/shop/123", map[string]string{":id": "123"}), + matchTestInfo("/v1/shop/:id\\((a|b|c)\\)", "/v1/shop/123(a)", map[string]string{":id": "123"}), + matchTestInfo("/v1/shop/:id\\((a|b|c)\\)", "/v1/shop/123(b)", map[string]string{":id": "123"}), + matchTestInfo("/v1/shop/:id\\((a|b|c)\\)", "/v1/shop/123(c)", map[string]string{":id": "123"}), + matchTestInfo("/:year:int/:month:int/:id/:endid", "/1111/111/aaa/aaa", map[string]string{":year": "1111", ":month": "111", ":id": "aaa", ":endid": "aaa"}), + matchTestInfo("/v1/shop/:id/:name", "/v1/shop/123/nike", map[string]string{":id": "123", ":name": "nike"}), + matchTestInfo("/v1/shop/:id/account", "/v1/shop/123/account", map[string]string{":id": "123"}), + matchTestInfo("/v1/shop/:name:string", "/v1/shop/nike", map[string]string{":name": "nike"}), + matchTestInfo("/v1/shop/:id([0-9]+)", "/v1/shop//123", map[string]string{":id": "123"}), + matchTestInfo("/v1/shop/:id([0-9]+)_:name", "/v1/shop/123_nike", map[string]string{":id": "123", ":name": "nike"}), + matchTestInfo("/v1/shop/:id(.+)_cms.html", "/v1/shop/123_cms.html", map[string]string{":id": "123"}), + matchTestInfo("/v1/shop/cms_:id(.+)_:page(.+).html", "/v1/shop/cms_123_1.html", map[string]string{":id": "123", ":page": "1"}), + matchTestInfo("/v1/:v/cms/aaa_:id(.+)_:page(.+).html", "/v1/2/cms/aaa_123_1.html", map[string]string{":v": "2", ":id": "123", ":page": "1"}), + matchTestInfo("/v1/:v/cms_:id(.+)_:page(.+).html", "/v1/2/cms_123_1.html", map[string]string{":v": "2", ":id": "123", ":page": "1"}), + matchTestInfo("/v1/:v(.+)_cms/ttt_:id(.+)_:page(.+).html", "/v1/2_cms/ttt_123_1.html", map[string]string{":v": "2", ":id": "123", ":page": "1"}), + matchTestInfo("/api/projects/:pid/members/?:mid", "/api/projects/1/members", map[string]string{":pid": "1"}), + matchTestInfo("/api/projects/:pid/members/?:mid", "/api/projects/1/members/2", map[string]string{":pid": "1", ":mid": "2"}), + matchTestInfo("/?:year/?:month/?:day", "/2020/11/10", map[string]string{":year": "2020", ":month": "11", ":day": "10"}), + matchTestInfo("/?:year/?:month/?:day", "/2020/11", map[string]string{":year": "2020", ":month": "11"}), + matchTestInfo("/?:year", "/2020", map[string]string{":year": "2020"}), + matchTestInfo("/?:year([0-9]+)/?:month([0-9]+)/mid/?:day([0-9]+)/?:hour([0-9]+)", "/2020/11/mid/10/24", map[string]string{":year": "2020", ":month": "11", ":day": "10", ":hour": "24"}), + matchTestInfo("/?:year/?:month/mid/?:day/?:hour", "/2020/mid/10", map[string]string{":year": "2020", ":day": "10"}), + matchTestInfo("/?:year/?:month/mid/?:day/?:hour", "/2020/11/mid", map[string]string{":year": "2020", ":month": "11"}), + matchTestInfo("/?:year/?:month/mid/?:day/?:hour", "/mid/10/24", map[string]string{":day": "10", ":hour": "24"}), + matchTestInfo("/?:year([0-9]+)/:month([0-9]+)/mid/:day([0-9]+)/?:hour([0-9]+)", "/2020/11/mid/10/24", map[string]string{":year": "2020", ":month": "11", ":day": "10", ":hour": "24"}), + matchTestInfo("/?:year/:month/mid/:day/?:hour", "/11/mid/10/24", map[string]string{":month": "11", ":day": "10"}), + matchTestInfo("/?:year/:month/mid/:day/?:hour", "/2020/11/mid/10", map[string]string{":year": "2020", ":month": "11", ":day": "10"}), + matchTestInfo("/?:year/:month/mid/:day/?:hour", "/11/mid/10", map[string]string{":month": "11", ":day": "10"}), + + // not match example + // https://github.com/beego/beego/v2/issues/3865 + notMatchTestInfo("/read_:id:int\\.htm", "/read_222htm"), + notMatchTestInfo("/read_:id:int\\.htm", "/read_222_htm"), + notMatchTestInfo("/read_:id:int\\.htm", " /read_262shtm"), + + // test .html, .json not suffix + notMatchTestInfo(abcHTML, "/suffix.html/abc"), + matchTestInfo("/suffix/abc", abcHTML, nil), + matchTestInfo("/suffix/*", abcHTML, nil), + notMatchTestInfo("/suffix/*", "/suffix.html/a"), + notMatchTestInfo(abcSuffix, "/abc/suffix.html/a"), + matchTestInfo(abcSuffix, "/abc/suffix/a", nil), + notMatchTestInfo(abcSuffix, "/abc.j/suffix/a"), + // test for fix of issue 4946 + notMatchTestInfo("/suffix/:name", "/suffix.html/suffix.html"), + matchTestInfo("/suffix/:id/name", "/suffix/1234/name.html", map[string]string{":id": "1234", ":ext": "html"}), + // test for fix of issue 4961,path.join() lead to cross directory risk + matchTestInfo("/book1/:name/fixPath1/*.*", "/book1/name1/fixPath1/mybook/../mybook2.txt", map[string]string{":name": "name1", ":path": "mybook2"}), + notMatchTestInfo("/book1/:name/fixPath1/*.*", "/book1/name1/fixPath1/mybook/../../mybook2.txt"), + notMatchTestInfo("/book1/:name/fixPath1/*.*", "/book1/../fixPath1/mybook/../././////evil.txt"), + notMatchTestInfo("/book1/:name/fixPath1/*.*", "/book1/./fixPath1/mybook/../././////evil.txt"), + notMatchTestInfo("/book2/:type:string/fixPath1/:name", "/book2/type1/fixPath1/name1/../../././////evilType/evilName"), + notMatchTestInfo("/book2/:type:string/fixPath1/:name", "/book2/type1/fixPath1/name1/../../././////evilType/evilName"), + notMatchTestInfo("/book2/:type:string/fixPath1/:name", "/book2/type1/fixPath1/name1/../../././////evilType/evilName"), + } +} + +func TestTreeRouters(t *testing.T) { + for _, r := range routers { + shouldMatch := r.shouldMatchOrNot + tr := NewTree() + tr.AddRouter(r.pattern, "astaxie") + ctx := context.NewContext() + obj := tr.Match(r.requestUrl, ctx) + if !shouldMatch { + if obj != nil { + t.Fatal("pattern:", r.pattern, ", should not match", r.requestUrl) + } else { + continue + } + } + if obj == nil || obj.(string) != "astaxie" { + t.Fatal("pattern:", r.pattern+", can't match obj, Expect ", r.requestUrl) + } + if r.params != nil { + for k, v := range r.params { + if vv := ctx.Input.Param(k); vv != v { + t.Fatal("The Rule: " + r.pattern + "\nThe RequestURL:" + r.requestUrl + "\nThe Key is " + k + ", The Value should be: " + v + ", but get: " + vv) + } else if vv == "" && v != "" { + t.Fatal(r.pattern + " " + r.requestUrl + " get param empty:" + k) + } + } + } + } + time.Sleep(time.Second) +} + +func TestStaticPath(t *testing.T) { + tr := NewTree() + tr.AddRouter("/topic/:id", "wildcard") + tr.AddRouter("/topic", "static") + ctx := context.NewContext() + obj := tr.Match("/topic", ctx) + if obj == nil || obj.(string) != "static" { + t.Fatal("/topic is a static route") + } + obj = tr.Match("/topic/1", ctx) + if obj == nil || obj.(string) != "wildcard" { + t.Fatal("/topic/1 is a wildcard route") + } +} + +func TestAddTree(t *testing.T) { + tr := NewTree() + tr.AddRouter("/shop/:id/account", "astaxie") + tr.AddRouter("/shop/:sd/ttt_:id(.+)_:page(.+).html", "astaxie") + t1 := NewTree() + t1.AddTree("/v1/zl", tr) + ctx := context.NewContext() + obj := t1.Match("/v1/zl/shop/123/account", ctx) + if obj == nil || obj.(string) != "astaxie" { + t.Fatal("/v1/zl/shop/:id/account can't get obj ") + } + if ctx.Input.ParamsLen() == 0 { + t.Fatal("get param error") + } + if ctx.Input.Param(":id") != "123" { + t.Fatal("get :id param error") + } + ctx.Input.Reset(ctx) + obj = t1.Match("/v1/zl/shop/123/ttt_1_12.html", ctx) + if obj == nil || obj.(string) != "astaxie" { + t.Fatal("/v1/zl//shop/:sd/ttt_:id(.+)_:page(.+).html can't get obj ") + } + if ctx.Input.ParamsLen() == 0 { + t.Fatal("get param error") + } + if ctx.Input.Param(":sd") != "123" || ctx.Input.Param(":id") != "1" || ctx.Input.Param(":page") != "12" { + t.Fatal("get :sd :id :page param error") + } + + t2 := NewTree() + t2.AddTree("/v1/:shopid", tr) + ctx.Input.Reset(ctx) + obj = t2.Match("/v1/zl/shop/123/account", ctx) + if obj == nil || obj.(string) != "astaxie" { + t.Fatal("/v1/:shopid/shop/:id/account can't get obj ") + } + if ctx.Input.ParamsLen() == 0 { + t.Fatal("get param error") + } + if ctx.Input.Param(":id") != "123" || ctx.Input.Param(":shopid") != "zl" { + t.Fatal("get :id :shopid param error") + } + ctx.Input.Reset(ctx) + obj = t2.Match("/v1/zl/shop/123/ttt_1_12.html", ctx) + if obj == nil || obj.(string) != "astaxie" { + t.Fatal("/v1/:shopid/shop/:sd/ttt_:id(.+)_:page(.+).html can't get obj ") + } + if ctx.Input.ParamsLen() == 0 { + t.Fatal("get :shopid param error") + } + if ctx.Input.Param(":sd") != "123" || ctx.Input.Param(":id") != "1" || ctx.Input.Param(":page") != "12" || ctx.Input.Param(":shopid") != "zl" { + t.Fatal("get :sd :id :page :shopid param error") + } +} + +func TestAddTree2(t *testing.T) { + tr := NewTree() + tr.AddRouter("/shop/:id/account", "astaxie") + tr.AddRouter("/shop/:sd/ttt_:id(.+)_:page(.+).html", "astaxie") + t3 := NewTree() + t3.AddTree("/:version(v1|v2)/:prefix", tr) + ctx := context.NewContext() + obj := t3.Match("/v1/zl/shop/123/account", ctx) + if obj == nil || obj.(string) != "astaxie" { + t.Fatal("/:version(v1|v2)/:prefix/shop/:id/account can't get obj ") + } + if ctx.Input.ParamsLen() == 0 { + t.Fatal("get param error") + } + if ctx.Input.Param(":id") != "123" || ctx.Input.Param(":prefix") != "zl" || ctx.Input.Param(":version") != "v1" { + t.Fatal("get :id :prefix :version param error") + } +} + +func TestAddTree3(t *testing.T) { + tr := NewTree() + tr.AddRouter("/create", "astaxie") + tr.AddRouter("/shop/:sd/account", "astaxie") + t3 := NewTree() + t3.AddTree("/table/:num", tr) + ctx := context.NewContext() + obj := t3.Match("/table/123/shop/123/account", ctx) + if obj == nil || obj.(string) != "astaxie" { + t.Fatal("/table/:num/shop/:sd/account can't get obj ") + } + if ctx.Input.ParamsLen() == 0 { + t.Fatal("get param error") + } + if ctx.Input.Param(":num") != "123" || ctx.Input.Param(":sd") != "123" { + t.Fatal("get :num :sd param error") + } + ctx.Input.Reset(ctx) + obj = t3.Match("/table/123/create", ctx) + if obj == nil || obj.(string) != "astaxie" { + t.Fatal("/table/:num/create can't get obj ") + } +} + +func TestAddTree4(t *testing.T) { + tr := NewTree() + tr.AddRouter("/create", "astaxie") + tr.AddRouter("/shop/:sd/:account", "astaxie") + t4 := NewTree() + t4.AddTree("/:info:int/:num/:id", tr) + ctx := context.NewContext() + obj := t4.Match("/12/123/456/shop/123/account", ctx) + if obj == nil || obj.(string) != "astaxie" { + t.Fatal("/:info:int/:num/:id/shop/:sd/:account can't get obj ") + } + if ctx.Input.ParamsLen() == 0 { + t.Fatal("get param error") + } + if ctx.Input.Param(":info") != "12" || ctx.Input.Param(":num") != "123" || + ctx.Input.Param(":id") != "456" || ctx.Input.Param(":sd") != "123" || + ctx.Input.Param(":account") != "account" { + t.Fatal("get :info :num :id :sd :account param error") + } + ctx.Input.Reset(ctx) + obj = t4.Match("/12/123/456/create", ctx) + if obj == nil || obj.(string) != "astaxie" { + t.Fatal("/:info:int/:num/:id/create can't get obj ") + } +} + +// Test for issue #1595 +func TestAddTree5(t *testing.T) { + tr := NewTree() + tr.AddRouter("/v1/shop/:id", "shopdetail") + tr.AddRouter("/v1/shop/", "shophome") + ctx := context.NewContext() + obj := tr.Match("/v1/shop/", ctx) + if obj == nil || obj.(string) != "shophome" { + t.Fatal("url /v1/shop/ need match router /v1/shop/ ") + } +} + +func TestSplitPath(t *testing.T) { + a := splitPath("") + if len(a) != 0 { + t.Fatal("/ should retrun []") + } + a = splitPath("/") + if len(a) != 0 { + t.Fatal("/ should retrun []") + } + a = splitPath("/admin") + if len(a) != 1 || a[0] != "admin" { + t.Fatal("/admin should retrun [admin]") + } + a = splitPath("/admin/") + if len(a) != 1 || a[0] != "admin" { + t.Fatal("/admin/ should retrun [admin]") + } + a = splitPath("/admin/users") + if len(a) != 2 || a[0] != "admin" || a[1] != "users" { + t.Fatal("/admin should retrun [admin users]") + } + a = splitPath("/admin/:id:int") + if len(a) != 2 || a[0] != "admin" || a[1] != ":id:int" { + t.Fatal("/admin should retrun [admin :id:int]") + } +} + +func TestSplitSegment(t *testing.T) { + items := map[string]struct { + isReg bool + params []string + regStr string + }{ + "admin": {false, nil, ""}, + "*": {true, []string{":splat"}, ""}, + "*.*": {true, []string{".", ":path", ":ext"}, ""}, + ":id": {true, []string{":id"}, ""}, + "?:id": {true, []string{":", ":id"}, ""}, + ":id:int": {true, []string{":id"}, "([0-9]+)"}, + ":name:string": {true, []string{":name"}, `([\w]+)`}, + ":id([0-9]+)": {true, []string{":id"}, `([0-9]+)`}, + ":id([0-9]+)_:name": {true, []string{":id", ":name"}, `([0-9]+)_(.+)`}, + ":id(.+)_cms.html": {true, []string{":id"}, `(.+)_cms.html`}, + ":id(.+)_cms\\.html": {true, []string{":id"}, `(.+)_cms\.html`}, + "cms_:id(.+)_:page(.+).html": {true, []string{":id", ":page"}, `cms_(.+)_(.+).html`}, + `:app(a|b|c)`: {true, []string{":app"}, `(a|b|c)`}, + `:app\((a|b|c)\)`: {true, []string{":app"}, `(.+)\((a|b|c)\)`}, + } + + for pattern, v := range items { + b, w, r := splitSegment(pattern) + if b != v.isReg || r != v.regStr || strings.Join(w, ",") != strings.Join(v.params, ",") { + t.Fatalf("%s should return %t,%s,%q, got %t,%s,%q", pattern, v.isReg, v.params, v.regStr, b, w, r) + } + } +} diff --git a/unregroute_test.go b/server/web/unregroute_test.go similarity index 79% rename from unregroute_test.go rename to server/web/unregroute_test.go index 08b1b77b22..9265011b0d 100644 --- a/unregroute_test.go +++ b/server/web/unregroute_test.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package beego +package web import ( "net/http" @@ -27,13 +27,14 @@ import ( // that embed parent routers. // -const contentRootOriginal = "ok-original-root" -const contentLevel1Original = "ok-original-level1" -const contentLevel2Original = "ok-original-level2" - -const contentRootReplacement = "ok-replacement-root" -const contentLevel1Replacement = "ok-replacement-level1" -const contentLevel2Replacement = "ok-replacement-level2" +const ( + contentRootOriginal = "ok-original-root" + contentLevel1Original = "ok-original-level1" + contentLevel2Original = "ok-original-level2" + contentRootReplacement = "ok-replacement-root" + contentLevel1Replacement = "ok-replacement-level1" + contentLevel2Replacement = "ok-replacement-level2" +) // TestPreUnregController will supply content for the original routes, // before unregistration @@ -44,9 +45,11 @@ type TestPreUnregController struct { func (tc *TestPreUnregController) GetFixedRoot() { tc.Ctx.Output.Body([]byte(contentRootOriginal)) } + func (tc *TestPreUnregController) GetFixedLevel1() { tc.Ctx.Output.Body([]byte(contentLevel1Original)) } + func (tc *TestPreUnregController) GetFixedLevel2() { tc.Ctx.Output.Body([]byte(contentLevel2Original)) } @@ -60,9 +63,11 @@ type TestPostUnregController struct { func (tc *TestPostUnregController) GetFixedRoot() { tc.Ctx.Output.Body([]byte(contentRootReplacement)) } + func (tc *TestPostUnregController) GetFixedLevel1() { tc.Ctx.Output.Body([]byte(contentLevel1Replacement)) } + func (tc *TestPostUnregController) GetFixedLevel2() { tc.Ctx.Output.Body([]byte(contentLevel2Replacement)) } @@ -71,13 +76,12 @@ func (tc *TestPostUnregController) GetFixedLevel2() { // In this case, for a path like "/level1/level2" or "/level1", those actions // should remain intact, and continue to serve the original content. func TestUnregisterFixedRouteRoot(t *testing.T) { - - var method = "GET" + method := "GET" handler := NewControllerRegister() - handler.Add("/", &TestPreUnregController{}, "get:GetFixedRoot") - handler.Add("/level1", &TestPreUnregController{}, "get:GetFixedLevel1") - handler.Add("/level1/level2", &TestPreUnregController{}, "get:GetFixedLevel2") + handler.Add("/", &TestPreUnregController{}, WithRouterMethods(&TestPreUnregController{}, "get:GetFixedRoot")) + handler.Add("/level1", &TestPreUnregController{}, WithRouterMethods(&TestPreUnregController{}, "get:GetFixedLevel1")) + handler.Add("/level1/level2", &TestPreUnregController{}, WithRouterMethods(&TestPreUnregController{}, "get:GetFixedLevel2")) // Test original root testHelperFnContentCheck(t, handler, "Test original root", @@ -96,7 +100,7 @@ func TestUnregisterFixedRouteRoot(t *testing.T) { // Replace the root path TestPreUnregController action with the action from // TestPostUnregController - handler.Add("/", &TestPostUnregController{}, "get:GetFixedRoot") + handler.Add("/", &TestPostUnregController{}, WithRouterMethods(&TestPostUnregController{}, "get:GetFixedRoot")) // Test replacement root (expect change) testHelperFnContentCheck(t, handler, "Test replacement root (expect change)", method, "/", contentRootReplacement) @@ -106,20 +110,18 @@ func TestUnregisterFixedRouteRoot(t *testing.T) { // Test level 2 (expect no change from the original) testHelperFnContentCheck(t, handler, "Test level 2 (expect no change from the original)", method, "/level1/level2", contentLevel2Original) - } // TestUnregisterFixedRouteLevel1 replaces just the "/level1" fixed route path. // In this case, for a path like "/level1/level2" or "/", those actions // should remain intact, and continue to serve the original content. func TestUnregisterFixedRouteLevel1(t *testing.T) { - - var method = "GET" + method := "GET" handler := NewControllerRegister() - handler.Add("/", &TestPreUnregController{}, "get:GetFixedRoot") - handler.Add("/level1", &TestPreUnregController{}, "get:GetFixedLevel1") - handler.Add("/level1/level2", &TestPreUnregController{}, "get:GetFixedLevel2") + handler.Add("/", &TestPreUnregController{}, WithRouterMethods(&TestPreUnregController{}, "get:GetFixedRoot")) + handler.Add("/level1", &TestPreUnregController{}, WithRouterMethods(&TestPreUnregController{}, "get:GetFixedLevel1")) + handler.Add("/level1/level2", &TestPreUnregController{}, WithRouterMethods(&TestPreUnregController{}, "get:GetFixedLevel2")) // Test original root testHelperFnContentCheck(t, handler, @@ -146,7 +148,7 @@ func TestUnregisterFixedRouteLevel1(t *testing.T) { // Replace the "level1" path TestPreUnregController action with the action from // TestPostUnregController - handler.Add("/level1", &TestPostUnregController{}, "get:GetFixedLevel1") + handler.Add("/level1", &TestPostUnregController{}, WithRouterMethods(&TestPostUnregController{}, "get:GetFixedLevel1")) // Test replacement root (expect no change from the original) testHelperFnContentCheck(t, handler, "Test replacement root (expect no change from the original)", method, "/", contentRootOriginal) @@ -156,20 +158,18 @@ func TestUnregisterFixedRouteLevel1(t *testing.T) { // Test level 2 (expect no change from the original) testHelperFnContentCheck(t, handler, "Test level 2 (expect no change from the original)", method, "/level1/level2", contentLevel2Original) - } // TestUnregisterFixedRouteLevel2 unregisters just the "/level1/level2" fixed // route path. In this case, for a path like "/level1" or "/", those actions // should remain intact, and continue to serve the original content. func TestUnregisterFixedRouteLevel2(t *testing.T) { - - var method = "GET" + method := "GET" handler := NewControllerRegister() - handler.Add("/", &TestPreUnregController{}, "get:GetFixedRoot") - handler.Add("/level1", &TestPreUnregController{}, "get:GetFixedLevel1") - handler.Add("/level1/level2", &TestPreUnregController{}, "get:GetFixedLevel2") + handler.Add("/", &TestPreUnregController{}, WithRouterMethods(&TestPreUnregController{}, "get:GetFixedRoot")) + handler.Add("/level1", &TestPreUnregController{}, WithRouterMethods(&TestPreUnregController{}, "get:GetFixedLevel1")) + handler.Add("/level1/level2", &TestPreUnregController{}, WithRouterMethods(&TestPreUnregController{}, "get:GetFixedLevel2")) // Test original root testHelperFnContentCheck(t, handler, @@ -196,7 +196,7 @@ func TestUnregisterFixedRouteLevel2(t *testing.T) { // Replace the "/level1/level2" path TestPreUnregController action with the action from // TestPostUnregController - handler.Add("/level1/level2", &TestPostUnregController{}, "get:GetFixedLevel2") + handler.Add("/level1/level2", &TestPostUnregController{}, WithRouterMethods(&TestPostUnregController{}, "get:GetFixedLevel2")) // Test replacement root (expect no change from the original) testHelperFnContentCheck(t, handler, "Test replacement root (expect no change from the original)", method, "/", contentRootOriginal) @@ -206,12 +206,11 @@ func TestUnregisterFixedRouteLevel2(t *testing.T) { // Test level 2 (expect change) testHelperFnContentCheck(t, handler, "Test level 2 (expect change)", method, "/level1/level2", contentLevel2Replacement) - } func testHelperFnContentCheck(t *testing.T, handler *ControllerRegister, - testName, method, path, expectedBodyContent string) { - + testName, method, path, expectedBodyContent string, +) { r, err := http.NewRequest(method, path, nil) if err != nil { t.Errorf("httpRecorderBodyTest NewRequest error: %v", err) diff --git a/session/README.md b/session/README.md deleted file mode 100644 index 6d0a297e3c..0000000000 --- a/session/README.md +++ /dev/null @@ -1,114 +0,0 @@ -session -============== - -session is a Go session manager. It can use many session providers. Just like the `database/sql` and `database/sql/driver`. - -## How to install? - - go get github.com/astaxie/beego/session - - -## What providers are supported? - -As of now this session manager support memory, file, Redis and MySQL. - - -## How to use it? - -First you must import it - - import ( - "github.com/astaxie/beego/session" - ) - -Then in you web app init the global session manager - - var globalSessions *session.Manager - -* Use **memory** as provider: - - func init() { - globalSessions, _ = session.NewManager("memory", `{"cookieName":"gosessionid","gclifetime":3600}`) - go globalSessions.GC() - } - -* Use **file** as provider, the last param is the path where you want file to be stored: - - func init() { - globalSessions, _ = session.NewManager("file",`{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"./tmp"}`) - go globalSessions.GC() - } - -* Use **Redis** as provider, the last param is the Redis conn address,poolsize,password: - - func init() { - globalSessions, _ = session.NewManager("redis", `{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"127.0.0.1:6379,100,astaxie"}`) - go globalSessions.GC() - } - -* Use **MySQL** as provider, the last param is the DSN, learn more from [mysql](https://github.com/go-sql-driver/mysql#dsn-data-source-name): - - func init() { - globalSessions, _ = session.NewManager( - "mysql", `{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"username:password@protocol(address)/dbname?param=value"}`) - go globalSessions.GC() - } - -* Use **Cookie** as provider: - - func init() { - globalSessions, _ = session.NewManager( - "cookie", `{"cookieName":"gosessionid","enableSetCookie":false,"gclifetime":3600,"ProviderConfig":"{\"cookieName\":\"gosessionid\",\"securityKey\":\"beegocookiehashkey\"}"}`) - go globalSessions.GC() - } - - -Finally in the handlerfunc you can use it like this - - func login(w http.ResponseWriter, r *http.Request) { - sess := globalSessions.SessionStart(w, r) - defer sess.SessionRelease(w) - username := sess.Get("username") - fmt.Println(username) - if r.Method == "GET" { - t, _ := template.ParseFiles("login.gtpl") - t.Execute(w, nil) - } else { - fmt.Println("username:", r.Form["username"]) - sess.Set("username", r.Form["username"]) - fmt.Println("password:", r.Form["password"]) - } - } - - -## How to write own provider? - -When you develop a web app, maybe you want to write own provider because you must meet the requirements. - -Writing a provider is easy. You only need to define two struct types -(Session and Provider), which satisfy the interface definition. -Maybe you will find the **memory** provider is a good example. - - type SessionStore interface { - Set(key, value interface{}) error //set session value - Get(key interface{}) interface{} //get session value - Delete(key interface{}) error //delete session value - SessionID() string //back current sessionID - SessionRelease(w http.ResponseWriter) // release the resource & save data to provider & return the data - Flush() error //delete all data - } - - type Provider interface { - SessionInit(gclifetime int64, config string) error - SessionRead(sid string) (SessionStore, error) - SessionExist(sid string) bool - SessionRegenerate(oldsid, sid string) (SessionStore, error) - SessionDestroy(sid string) error - SessionAll() int //get all active session - SessionGC() - } - - -## LICENSE - -BSD License http://creativecommons.org/licenses/BSD/ diff --git a/session/ledis/ledis_session.go b/session/ledis/ledis_session.go deleted file mode 100644 index c0d4bf82b4..0000000000 --- a/session/ledis/ledis_session.go +++ /dev/null @@ -1,172 +0,0 @@ -// Package ledis provide session Provider -package ledis - -import ( - "net/http" - "strconv" - "strings" - "sync" - - "github.com/astaxie/beego/session" - "github.com/siddontang/ledisdb/config" - "github.com/siddontang/ledisdb/ledis" -) - -var ( - ledispder = &Provider{} - c *ledis.DB -) - -// SessionStore ledis session store -type SessionStore struct { - sid string - lock sync.RWMutex - values map[interface{}]interface{} - maxlifetime int64 -} - -// Set value in ledis session -func (ls *SessionStore) Set(key, value interface{}) error { - ls.lock.Lock() - defer ls.lock.Unlock() - ls.values[key] = value - return nil -} - -// Get value in ledis session -func (ls *SessionStore) Get(key interface{}) interface{} { - ls.lock.RLock() - defer ls.lock.RUnlock() - if v, ok := ls.values[key]; ok { - return v - } - return nil -} - -// Delete value in ledis session -func (ls *SessionStore) Delete(key interface{}) error { - ls.lock.Lock() - defer ls.lock.Unlock() - delete(ls.values, key) - return nil -} - -// Flush clear all values in ledis session -func (ls *SessionStore) Flush() error { - ls.lock.Lock() - defer ls.lock.Unlock() - ls.values = make(map[interface{}]interface{}) - return nil -} - -// SessionID get ledis session id -func (ls *SessionStore) SessionID() string { - return ls.sid -} - -// SessionRelease save session values to ledis -func (ls *SessionStore) SessionRelease(w http.ResponseWriter) { - b, err := session.EncodeGob(ls.values) - if err != nil { - return - } - c.Set([]byte(ls.sid), b) - c.Expire([]byte(ls.sid), ls.maxlifetime) -} - -// Provider ledis session provider -type Provider struct { - maxlifetime int64 - savePath string - db int -} - -// SessionInit init ledis session -// savepath like ledis server saveDataPath,pool size -// e.g. 127.0.0.1:6379,100,astaxie -func (lp *Provider) SessionInit(maxlifetime int64, savePath string) error { - var err error - lp.maxlifetime = maxlifetime - configs := strings.Split(savePath, ",") - if len(configs) == 1 { - lp.savePath = configs[0] - } else if len(configs) == 2 { - lp.savePath = configs[0] - lp.db, err = strconv.Atoi(configs[1]) - if err != nil { - return err - } - } - cfg := new(config.Config) - cfg.DataDir = lp.savePath - - var ledisInstance *ledis.Ledis - ledisInstance, err = ledis.Open(cfg) - if err != nil { - return err - } - c, err = ledisInstance.Select(lp.db) - return err -} - -// SessionRead read ledis session by sid -func (lp *Provider) SessionRead(sid string) (session.Store, error) { - var ( - kv map[interface{}]interface{} - err error - ) - - kvs, _ := c.Get([]byte(sid)) - - if len(kvs) == 0 { - kv = make(map[interface{}]interface{}) - } else { - if kv, err = session.DecodeGob(kvs); err != nil { - return nil, err - } - } - - ls := &SessionStore{sid: sid, values: kv, maxlifetime: lp.maxlifetime} - return ls, nil -} - -// SessionExist check ledis session exist by sid -func (lp *Provider) SessionExist(sid string) bool { - count, _ := c.Exists([]byte(sid)) - return count != 0 -} - -// SessionRegenerate generate new sid for ledis session -func (lp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { - count, _ := c.Exists([]byte(sid)) - if count == 0 { - // oldsid doesn't exists, set the new sid directly - // ignore error here, since if it return error - // the existed value will be 0 - c.Set([]byte(sid), []byte("")) - c.Expire([]byte(sid), lp.maxlifetime) - } else { - data, _ := c.Get([]byte(oldsid)) - c.Set([]byte(sid), data) - c.Expire([]byte(sid), lp.maxlifetime) - } - return lp.SessionRead(sid) -} - -// SessionDestroy delete ledis session by id -func (lp *Provider) SessionDestroy(sid string) error { - c.Del([]byte(sid)) - return nil -} - -// SessionGC Impelment method, no used. -func (lp *Provider) SessionGC() { -} - -// SessionAll return all active session -func (lp *Provider) SessionAll() int { - return 0 -} -func init() { - session.Register("ledis", ledispder) -} diff --git a/session/redis/sess_redis.go b/session/redis/sess_redis.go deleted file mode 100644 index 5c382d61e4..0000000000 --- a/session/redis/sess_redis.go +++ /dev/null @@ -1,261 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package redis for session provider -// -// depend on github.com/gomodule/redigo/redis -// -// go install github.com/gomodule/redigo/redis -// -// Usage: -// import( -// _ "github.com/astaxie/beego/session/redis" -// "github.com/astaxie/beego/session" -// ) -// -// func init() { -// globalSessions, _ = session.NewManager("redis", ``{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"127.0.0.1:7070"}``) -// go globalSessions.GC() -// } -// -// more docs: http://beego.me/docs/module/session.md -package redis - -import ( - "net/http" - "strconv" - "strings" - "sync" - "time" - - "github.com/astaxie/beego/session" - - "github.com/gomodule/redigo/redis" -) - -var redispder = &Provider{} - -// MaxPoolSize redis max pool size -var MaxPoolSize = 100 - -// SessionStore redis session store -type SessionStore struct { - p *redis.Pool - sid string - lock sync.RWMutex - values map[interface{}]interface{} - maxlifetime int64 -} - -// Set value in redis session -func (rs *SessionStore) Set(key, value interface{}) error { - rs.lock.Lock() - defer rs.lock.Unlock() - rs.values[key] = value - return nil -} - -// Get value in redis session -func (rs *SessionStore) Get(key interface{}) interface{} { - rs.lock.RLock() - defer rs.lock.RUnlock() - if v, ok := rs.values[key]; ok { - return v - } - return nil -} - -// Delete value in redis session -func (rs *SessionStore) Delete(key interface{}) error { - rs.lock.Lock() - defer rs.lock.Unlock() - delete(rs.values, key) - return nil -} - -// Flush clear all values in redis session -func (rs *SessionStore) Flush() error { - rs.lock.Lock() - defer rs.lock.Unlock() - rs.values = make(map[interface{}]interface{}) - return nil -} - -// SessionID get redis session id -func (rs *SessionStore) SessionID() string { - return rs.sid -} - -// SessionRelease save session values to redis -func (rs *SessionStore) SessionRelease(w http.ResponseWriter) { - b, err := session.EncodeGob(rs.values) - if err != nil { - return - } - c := rs.p.Get() - defer c.Close() - c.Do("SETEX", rs.sid, rs.maxlifetime, string(b)) -} - -// Provider redis session provider -type Provider struct { - maxlifetime int64 - savePath string - poolsize int - password string - dbNum int - poollist *redis.Pool -} - -// SessionInit init redis session -// savepath like redis server addr,pool size,password,dbnum,IdleTimeout second -// e.g. 127.0.0.1:6379,100,astaxie,0,30 -func (rp *Provider) SessionInit(maxlifetime int64, savePath string) error { - rp.maxlifetime = maxlifetime - configs := strings.Split(savePath, ",") - if len(configs) > 0 { - rp.savePath = configs[0] - } - if len(configs) > 1 { - poolsize, err := strconv.Atoi(configs[1]) - if err != nil || poolsize < 0 { - rp.poolsize = MaxPoolSize - } else { - rp.poolsize = poolsize - } - } else { - rp.poolsize = MaxPoolSize - } - if len(configs) > 2 { - rp.password = configs[2] - } - if len(configs) > 3 { - dbnum, err := strconv.Atoi(configs[3]) - if err != nil || dbnum < 0 { - rp.dbNum = 0 - } else { - rp.dbNum = dbnum - } - } else { - rp.dbNum = 0 - } - var idleTimeout time.Duration = 0 - if len(configs) > 4 { - timeout, err := strconv.Atoi(configs[4]) - if err == nil && timeout > 0 { - idleTimeout = time.Duration(timeout) * time.Second - } - } - rp.poollist = &redis.Pool{ - Dial: func() (redis.Conn, error) { - c, err := redis.Dial("tcp", rp.savePath) - if err != nil { - return nil, err - } - if rp.password != "" { - if _, err = c.Do("AUTH", rp.password); err != nil { - c.Close() - return nil, err - } - } - // some redis proxy such as twemproxy is not support select command - if rp.dbNum > 0 { - _, err = c.Do("SELECT", rp.dbNum) - if err != nil { - c.Close() - return nil, err - } - } - return c, err - }, - MaxIdle: rp.poolsize, - } - - rp.poollist.IdleTimeout = idleTimeout - - return rp.poollist.Get().Err() -} - -// SessionRead read redis session by sid -func (rp *Provider) SessionRead(sid string) (session.Store, error) { - c := rp.poollist.Get() - defer c.Close() - - var kv map[interface{}]interface{} - - kvs, err := redis.String(c.Do("GET", sid)) - if err != nil && err != redis.ErrNil { - return nil, err - } - if len(kvs) == 0 { - kv = make(map[interface{}]interface{}) - } else { - if kv, err = session.DecodeGob([]byte(kvs)); err != nil { - return nil, err - } - } - - rs := &SessionStore{p: rp.poollist, sid: sid, values: kv, maxlifetime: rp.maxlifetime} - return rs, nil -} - -// SessionExist check redis session exist by sid -func (rp *Provider) SessionExist(sid string) bool { - c := rp.poollist.Get() - defer c.Close() - - if existed, err := redis.Int(c.Do("EXISTS", sid)); err != nil || existed == 0 { - return false - } - return true -} - -// SessionRegenerate generate new sid for redis session -func (rp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { - c := rp.poollist.Get() - defer c.Close() - - if existed, _ := redis.Int(c.Do("EXISTS", oldsid)); existed == 0 { - // oldsid doesn't exists, set the new sid directly - // ignore error here, since if it return error - // the existed value will be 0 - c.Do("SET", sid, "", "EX", rp.maxlifetime) - } else { - c.Do("RENAME", oldsid, sid) - c.Do("EXPIRE", sid, rp.maxlifetime) - } - return rp.SessionRead(sid) -} - -// SessionDestroy delete redis session by id -func (rp *Provider) SessionDestroy(sid string) error { - c := rp.poollist.Get() - defer c.Close() - - c.Do("DEL", sid) - return nil -} - -// SessionGC Impelment method, no used. -func (rp *Provider) SessionGC() { -} - -// SessionAll return all activeSession -func (rp *Provider) SessionAll() int { - return 0 -} - -func init() { - session.Register("redis", redispder) -} diff --git a/session/redis_cluster/redis_cluster.go b/session/redis_cluster/redis_cluster.go deleted file mode 100644 index 2fe300df1b..0000000000 --- a/session/redis_cluster/redis_cluster.go +++ /dev/null @@ -1,220 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package redis for session provider -// -// depend on github.com/go-redis/redis -// -// go install github.com/go-redis/redis -// -// Usage: -// import( -// _ "github.com/astaxie/beego/session/redis_cluster" -// "github.com/astaxie/beego/session" -// ) -// -// func init() { -// globalSessions, _ = session.NewManager("redis_cluster", ``{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"127.0.0.1:7070;127.0.0.1:7071"}``) -// go globalSessions.GC() -// } -// -// more docs: http://beego.me/docs/module/session.md -package redis_cluster -import ( - "net/http" - "strconv" - "strings" - "sync" - "github.com/astaxie/beego/session" - rediss "github.com/go-redis/redis" - "time" -) - -var redispder = &Provider{} - -// MaxPoolSize redis_cluster max pool size -var MaxPoolSize = 1000 - -// SessionStore redis_cluster session store -type SessionStore struct { - p *rediss.ClusterClient - sid string - lock sync.RWMutex - values map[interface{}]interface{} - maxlifetime int64 -} - -// Set value in redis_cluster session -func (rs *SessionStore) Set(key, value interface{}) error { - rs.lock.Lock() - defer rs.lock.Unlock() - rs.values[key] = value - return nil -} - -// Get value in redis_cluster session -func (rs *SessionStore) Get(key interface{}) interface{} { - rs.lock.RLock() - defer rs.lock.RUnlock() - if v, ok := rs.values[key]; ok { - return v - } - return nil -} - -// Delete value in redis_cluster session -func (rs *SessionStore) Delete(key interface{}) error { - rs.lock.Lock() - defer rs.lock.Unlock() - delete(rs.values, key) - return nil -} - -// Flush clear all values in redis_cluster session -func (rs *SessionStore) Flush() error { - rs.lock.Lock() - defer rs.lock.Unlock() - rs.values = make(map[interface{}]interface{}) - return nil -} - -// SessionID get redis_cluster session id -func (rs *SessionStore) SessionID() string { - return rs.sid -} - -// SessionRelease save session values to redis_cluster -func (rs *SessionStore) SessionRelease(w http.ResponseWriter) { - b, err := session.EncodeGob(rs.values) - if err != nil { - return - } - c := rs.p - c.Set(rs.sid, string(b), time.Duration(rs.maxlifetime) * time.Second) -} - -// Provider redis_cluster session provider -type Provider struct { - maxlifetime int64 - savePath string - poolsize int - password string - dbNum int - poollist *rediss.ClusterClient -} - -// SessionInit init redis_cluster session -// savepath like redis server addr,pool size,password,dbnum -// e.g. 127.0.0.1:6379;127.0.0.1:6380,100,test,0 -func (rp *Provider) SessionInit(maxlifetime int64, savePath string) error { - rp.maxlifetime = maxlifetime - configs := strings.Split(savePath, ",") - if len(configs) > 0 { - rp.savePath = configs[0] - } - if len(configs) > 1 { - poolsize, err := strconv.Atoi(configs[1]) - if err != nil || poolsize < 0 { - rp.poolsize = MaxPoolSize - } else { - rp.poolsize = poolsize - } - } else { - rp.poolsize = MaxPoolSize - } - if len(configs) > 2 { - rp.password = configs[2] - } - if len(configs) > 3 { - dbnum, err := strconv.Atoi(configs[3]) - if err != nil || dbnum < 0 { - rp.dbNum = 0 - } else { - rp.dbNum = dbnum - } - } else { - rp.dbNum = 0 - } - - rp.poollist = rediss.NewClusterClient(&rediss.ClusterOptions{ - Addrs: strings.Split(rp.savePath, ";"), - Password: rp.password, - PoolSize: rp.poolsize, - }) - return rp.poollist.Ping().Err() -} - -// SessionRead read redis_cluster session by sid -func (rp *Provider) SessionRead(sid string) (session.Store, error) { - var kv map[interface{}]interface{} - kvs, err := rp.poollist.Get(sid).Result() - if err != nil && err != rediss.Nil { - return nil, err - } - if len(kvs) == 0 { - kv = make(map[interface{}]interface{}) - } else { - if kv, err = session.DecodeGob([]byte(kvs)); err != nil { - return nil, err - } - } - - rs := &SessionStore{p: rp.poollist, sid: sid, values: kv, maxlifetime: rp.maxlifetime} - return rs, nil -} - -// SessionExist check redis_cluster session exist by sid -func (rp *Provider) SessionExist(sid string) bool { - c := rp.poollist - if existed, err := c.Exists(sid).Result(); err != nil || existed == 0 { - return false - } - return true -} - -// SessionRegenerate generate new sid for redis_cluster session -func (rp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { - c := rp.poollist - - if existed, err := c.Exists(oldsid).Result(); err != nil || existed == 0 { - // oldsid doesn't exists, set the new sid directly - // ignore error here, since if it return error - // the existed value will be 0 - c.Set(sid, "", time.Duration(rp.maxlifetime) * time.Second) - } else { - c.Rename(oldsid, sid) - c.Expire(sid, time.Duration(rp.maxlifetime) * time.Second) - } - return rp.SessionRead(sid) -} - -// SessionDestroy delete redis session by id -func (rp *Provider) SessionDestroy(sid string) error { - c := rp.poollist - c.Del(sid) - return nil -} - -// SessionGC Impelment method, no used. -func (rp *Provider) SessionGC() { -} - -// SessionAll return all activeSession -func (rp *Provider) SessionAll() int { - return 0 -} - -func init() { - session.Register("redis_cluster", redispder) -} diff --git a/session/redis_sentinel/sess_redis_sentinel.go b/session/redis_sentinel/sess_redis_sentinel.go deleted file mode 100644 index 6ecb297707..0000000000 --- a/session/redis_sentinel/sess_redis_sentinel.go +++ /dev/null @@ -1,234 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package redis for session provider -// -// depend on github.com/go-redis/redis -// -// go install github.com/go-redis/redis -// -// Usage: -// import( -// _ "github.com/astaxie/beego/session/redis_sentinel" -// "github.com/astaxie/beego/session" -// ) -// -// func init() { -// globalSessions, _ = session.NewManager("redis_sentinel", ``{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig":"127.0.0.1:26379;127.0.0.2:26379"}``) -// go globalSessions.GC() -// } -// -// more detail about params: please check the notes on the function SessionInit in this package -package redis_sentinel - -import ( - "github.com/astaxie/beego/session" - "github.com/go-redis/redis" - "net/http" - "strconv" - "strings" - "sync" - "time" -) - -var redispder = &Provider{} - -// DefaultPoolSize redis_sentinel default pool size -var DefaultPoolSize = 100 - -// SessionStore redis_sentinel session store -type SessionStore struct { - p *redis.Client - sid string - lock sync.RWMutex - values map[interface{}]interface{} - maxlifetime int64 -} - -// Set value in redis_sentinel session -func (rs *SessionStore) Set(key, value interface{}) error { - rs.lock.Lock() - defer rs.lock.Unlock() - rs.values[key] = value - return nil -} - -// Get value in redis_sentinel session -func (rs *SessionStore) Get(key interface{}) interface{} { - rs.lock.RLock() - defer rs.lock.RUnlock() - if v, ok := rs.values[key]; ok { - return v - } - return nil -} - -// Delete value in redis_sentinel session -func (rs *SessionStore) Delete(key interface{}) error { - rs.lock.Lock() - defer rs.lock.Unlock() - delete(rs.values, key) - return nil -} - -// Flush clear all values in redis_sentinel session -func (rs *SessionStore) Flush() error { - rs.lock.Lock() - defer rs.lock.Unlock() - rs.values = make(map[interface{}]interface{}) - return nil -} - -// SessionID get redis_sentinel session id -func (rs *SessionStore) SessionID() string { - return rs.sid -} - -// SessionRelease save session values to redis_sentinel -func (rs *SessionStore) SessionRelease(w http.ResponseWriter) { - b, err := session.EncodeGob(rs.values) - if err != nil { - return - } - c := rs.p - c.Set(rs.sid, string(b), time.Duration(rs.maxlifetime)*time.Second) -} - -// Provider redis_sentinel session provider -type Provider struct { - maxlifetime int64 - savePath string - poolsize int - password string - dbNum int - poollist *redis.Client - masterName string -} - -// SessionInit init redis_sentinel session -// savepath like redis sentinel addr,pool size,password,dbnum,masterName -// e.g. 127.0.0.1:26379;127.0.0.2:26379,100,1qaz2wsx,0,mymaster -func (rp *Provider) SessionInit(maxlifetime int64, savePath string) error { - rp.maxlifetime = maxlifetime - configs := strings.Split(savePath, ",") - if len(configs) > 0 { - rp.savePath = configs[0] - } - if len(configs) > 1 { - poolsize, err := strconv.Atoi(configs[1]) - if err != nil || poolsize < 0 { - rp.poolsize = DefaultPoolSize - } else { - rp.poolsize = poolsize - } - } else { - rp.poolsize = DefaultPoolSize - } - if len(configs) > 2 { - rp.password = configs[2] - } - if len(configs) > 3 { - dbnum, err := strconv.Atoi(configs[3]) - if err != nil || dbnum < 0 { - rp.dbNum = 0 - } else { - rp.dbNum = dbnum - } - } else { - rp.dbNum = 0 - } - if len(configs) > 4 { - if configs[4] != "" { - rp.masterName = configs[4] - } else { - rp.masterName = "mymaster" - } - } else { - rp.masterName = "mymaster" - } - - rp.poollist = redis.NewFailoverClient(&redis.FailoverOptions{ - SentinelAddrs: strings.Split(rp.savePath, ";"), - Password: rp.password, - PoolSize: rp.poolsize, - DB: rp.dbNum, - MasterName: rp.masterName, - }) - - return rp.poollist.Ping().Err() -} - -// SessionRead read redis_sentinel session by sid -func (rp *Provider) SessionRead(sid string) (session.Store, error) { - var kv map[interface{}]interface{} - kvs, err := rp.poollist.Get(sid).Result() - if err != nil && err != redis.Nil { - return nil, err - } - if len(kvs) == 0 { - kv = make(map[interface{}]interface{}) - } else { - if kv, err = session.DecodeGob([]byte(kvs)); err != nil { - return nil, err - } - } - - rs := &SessionStore{p: rp.poollist, sid: sid, values: kv, maxlifetime: rp.maxlifetime} - return rs, nil -} - -// SessionExist check redis_sentinel session exist by sid -func (rp *Provider) SessionExist(sid string) bool { - c := rp.poollist - if existed, err := c.Exists(sid).Result(); err != nil || existed == 0 { - return false - } - return true -} - -// SessionRegenerate generate new sid for redis_sentinel session -func (rp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { - c := rp.poollist - - if existed, err := c.Exists(oldsid).Result(); err != nil || existed == 0 { - // oldsid doesn't exists, set the new sid directly - // ignore error here, since if it return error - // the existed value will be 0 - c.Set(sid, "", time.Duration(rp.maxlifetime)*time.Second) - } else { - c.Rename(oldsid, sid) - c.Expire(sid, time.Duration(rp.maxlifetime)*time.Second) - } - return rp.SessionRead(sid) -} - -// SessionDestroy delete redis session by id -func (rp *Provider) SessionDestroy(sid string) error { - c := rp.poollist - c.Del(sid) - return nil -} - -// SessionGC Impelment method, no used. -func (rp *Provider) SessionGC() { -} - -// SessionAll return all activeSession -func (rp *Provider) SessionAll() int { - return 0 -} - -func init() { - session.Register("redis_sentinel", redispder) -} diff --git a/session/redis_sentinel/sess_redis_sentinel_test.go b/session/redis_sentinel/sess_redis_sentinel_test.go deleted file mode 100644 index fd4155c632..0000000000 --- a/session/redis_sentinel/sess_redis_sentinel_test.go +++ /dev/null @@ -1,90 +0,0 @@ -package redis_sentinel - -import ( - "net/http" - "net/http/httptest" - "testing" - - "github.com/astaxie/beego/session" -) - -func TestRedisSentinel(t *testing.T) { - sessionConfig := &session.ManagerConfig{ - CookieName: "gosessionid", - EnableSetCookie: true, - Gclifetime: 3600, - Maxlifetime: 3600, - Secure: false, - CookieLifeTime: 3600, - ProviderConfig: "127.0.0.1:6379,100,,0,master", - } - globalSessions, e := session.NewManager("redis_sentinel", sessionConfig) - if e != nil { - t.Log(e) - return - } - //todo test if e==nil - go globalSessions.GC() - - r, _ := http.NewRequest("GET", "/", nil) - w := httptest.NewRecorder() - - sess, err := globalSessions.SessionStart(w, r) - if err != nil { - t.Fatal("session start failed:", err) - } - defer sess.SessionRelease(w) - - // SET AND GET - err = sess.Set("username", "astaxie") - if err != nil { - t.Fatal("set username failed:", err) - } - username := sess.Get("username") - if username != "astaxie" { - t.Fatal("get username failed") - } - - // DELETE - err = sess.Delete("username") - if err != nil { - t.Fatal("delete username failed:", err) - } - username = sess.Get("username") - if username != nil { - t.Fatal("delete username failed") - } - - // FLUSH - err = sess.Set("username", "astaxie") - if err != nil { - t.Fatal("set failed:", err) - } - err = sess.Set("password", "1qaz2wsx") - if err != nil { - t.Fatal("set failed:", err) - } - username = sess.Get("username") - if username != "astaxie" { - t.Fatal("get username failed") - } - password := sess.Get("password") - if password != "1qaz2wsx" { - t.Fatal("get password failed") - } - err = sess.Flush() - if err != nil { - t.Fatal("flush failed:", err) - } - username = sess.Get("username") - if username != nil { - t.Fatal("flush failed") - } - password = sess.Get("password") - if password != nil { - t.Fatal("flush failed") - } - - sess.SessionRelease(w) - -} diff --git a/sonar-project.properties b/sonar-project.properties new file mode 100644 index 0000000000..dbd5d6ae63 --- /dev/null +++ b/sonar-project.properties @@ -0,0 +1,7 @@ +sonar.organization=beego +sonar.projectKey=beego_beego + +# relative paths to source directories. More details and properties are described +# in https://sonarcloud.io/documentation/project-administration/narrowing-the-focus/ +sonar.sources=. +sonar.exclusions=**/*_test.go diff --git a/task/governor_command.go b/task/governor_command.go new file mode 100644 index 0000000000..705d2d108f --- /dev/null +++ b/task/governor_command.go @@ -0,0 +1,88 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package task + +import ( + "context" + "errors" + "fmt" + "html/template" + + "github.com/beego/beego/v2/core/admin" +) + +type listTaskCommand struct{} + +func (l *listTaskCommand) Execute(params ...interface{}) *admin.Result { + resultList := make([][]string, 0, len(globalTaskManager.adminTaskList)) + for tname, tk := range globalTaskManager.adminTaskList { + result := []string{ + template.HTMLEscapeString(tname), + template.HTMLEscapeString(tk.GetSpec(nil)), + template.HTMLEscapeString(tk.GetStatus(nil)), + template.HTMLEscapeString(tk.GetPrev(context.Background()).String()), + } + resultList = append(resultList, result) + } + + return &admin.Result{ + Status: 200, + Content: resultList, + } +} + +type runTaskCommand struct{} + +func (r *runTaskCommand) Execute(params ...interface{}) *admin.Result { + if len(params) == 0 { + return &admin.Result{ + Status: 400, + Error: errors.New("task name not passed"), + } + } + + tn, ok := params[0].(string) + + if !ok { + return &admin.Result{ + Status: 400, + Error: errors.New("parameter is invalid"), + } + } + + if t, ok := globalTaskManager.adminTaskList[tn]; ok { + err := t.Run(context.Background()) + if err != nil { + return &admin.Result{ + Status: 500, + Error: err, + } + } + return &admin.Result{ + Status: 200, + Content: t.GetStatus(context.Background()), + } + } else { + return &admin.Result{ + Status: 400, + Error: fmt.Errorf("task with name %s not found", tn), + } + } +} + +func registerCommands() { + admin.RegisterCommand("task", "list", &listTaskCommand{}) + admin.RegisterCommand("task", "run", &runTaskCommand{}) +} diff --git a/task/governor_command_test.go b/task/governor_command_test.go new file mode 100644 index 0000000000..c3547cdfe4 --- /dev/null +++ b/task/governor_command_test.go @@ -0,0 +1,115 @@ +// Copyright 2020 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package task + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +type countTask struct { + cnt int + mockErr error +} + +func (c *countTask) GetSpec(ctx context.Context) string { + return "AAA" +} + +func (c *countTask) GetStatus(ctx context.Context) string { + return "SUCCESS" +} + +func (c *countTask) Run(ctx context.Context) error { + c.cnt++ + return c.mockErr +} + +func (c *countTask) SetNext(ctx context.Context, time time.Time) { +} + +func (c *countTask) GetNext(ctx context.Context) time.Time { + return time.Now() +} + +func (c *countTask) SetPrev(ctx context.Context, time time.Time) { +} + +func (c *countTask) GetPrev(ctx context.Context) time.Time { + return time.Now() +} + +func (c *countTask) GetTimeout(ctx context.Context) time.Duration { + return 0 +} + +func TestRunTaskCommand_Execute(t *testing.T) { + task := &countTask{} + AddTask("count", task) + + cmd := &runTaskCommand{} + + res := cmd.Execute() + assert.NotNil(t, res) + assert.NotNil(t, res.Error) + assert.Equal(t, "task name not passed", res.Error.Error()) + + res = cmd.Execute(10) + assert.NotNil(t, res) + assert.NotNil(t, res.Error) + assert.Equal(t, "parameter is invalid", res.Error.Error()) + + res = cmd.Execute("CCCC") + assert.NotNil(t, res) + assert.NotNil(t, res.Error) + assert.Equal(t, "task with name CCCC not found", res.Error.Error()) + + res = cmd.Execute("count") + assert.NotNil(t, res) + assert.True(t, res.IsSuccess()) + + task.mockErr = errors.New("mock error") + res = cmd.Execute("count") + assert.NotNil(t, res) + assert.NotNil(t, res.Error) + assert.Equal(t, "mock error", res.Error.Error()) +} + +func TestListTaskCommand_Execute(t *testing.T) { + task := &countTask{} + + cmd := &listTaskCommand{} + + res := cmd.Execute() + + assert.True(t, res.IsSuccess()) + + _, ok := res.Content.([][]string) + assert.True(t, ok) + + AddTask("count", task) + + res = cmd.Execute() + + assert.True(t, res.IsSuccess()) + + rl, ok := res.Content.([][]string) + assert.True(t, ok) + assert.Equal(t, 1, len(rl)) +} diff --git a/toolbox/task.go b/task/task.go similarity index 61% rename from toolbox/task.go rename to task/task.go index d134302383..870210eccc 100644 --- a/toolbox/task.go +++ b/task/task.go @@ -12,9 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. -package toolbox +package task import ( + "context" "log" "math" "sort" @@ -30,18 +31,34 @@ type bounds struct { names map[string]uint } -// The bounds for each field. -var ( - AdminTaskList map[string]Tasker - taskLock sync.Mutex +type taskManager struct { + adminTaskList map[string]Tasker + taskLock sync.RWMutex stop chan bool changed chan bool - isstart bool - seconds = bounds{0, 59, nil} - minutes = bounds{0, 59, nil} - hours = bounds{0, 23, nil} - days = bounds{1, 31, nil} - months = bounds{1, 12, map[string]uint{ + started bool + wait sync.WaitGroup +} + +func newTaskManager() *taskManager { + return &taskManager{ + adminTaskList: make(map[string]Tasker), + taskLock: sync.RWMutex{}, + stop: make(chan bool), + changed: make(chan bool), + started: false, + } +} + +// The bounds for each field. +var ( + globalTaskManager *taskManager + + seconds = bounds{0, 59, nil} + minutes = bounds{0, 59, nil} + hours = bounds{0, 23, nil} + days = bounds{1, 31, nil} + months = bounds{1, 12, map[string]uint{ "jan": 1, "feb": 2, "mar": 3, @@ -82,17 +99,18 @@ type Schedule struct { } // TaskFunc task func type -type TaskFunc func() error +type TaskFunc func(ctx context.Context) error // Tasker task interface type Tasker interface { - GetSpec() string - GetStatus() string - Run() error - SetNext(time.Time) - GetNext() time.Time - SetPrev(time.Time) - GetPrev() time.Time + GetSpec(ctx context.Context) string + GetStatus(ctx context.Context) string + Run(ctx context.Context) error + SetNext(context.Context, time.Time) + GetNext(ctx context.Context) time.Time + SetPrev(context.Context, time.Time) + GetPrev(ctx context.Context) time.Time + GetTimeout(ctx context.Context) time.Duration } // task error @@ -102,6 +120,8 @@ type taskerr struct { } // Task task struct +// It's not a thread-safe structure. +// Only nearest errors will be saved in ErrList type Task struct { Taskname string Spec *Schedule @@ -109,68 +129,105 @@ type Task struct { DoFunc TaskFunc Prev time.Time Next time.Time - Errlist []*taskerr // like errtime:errinfo - ErrLimit int // max length for the errlist, 0 stand for no limit + Timeout time.Duration // timeout duration + Errlist []*taskerr // like errtime:errinfo + ErrLimit int // max length for the errlist, 0 stand for no limit + errCnt int // records the error count during the execution } // NewTask add new task with name, time and func -func NewTask(tname string, spec string, f TaskFunc) *Task { - +func NewTask(tname string, spec string, f TaskFunc, opts ...Option) *Task { task := &Task{ Taskname: tname, DoFunc: f, + // Make configurable ErrLimit: 100, SpecStr: spec, + // we only store the pointer, so it won't use too many space + Errlist: make([]*taskerr, 100, 100), + } + + for _, opt := range opts { + opt.apply(task) } + task.SetCron(spec) return task } // GetSpec get spec string -func (t *Task) GetSpec() string { +func (t *Task) GetSpec(context.Context) string { return t.SpecStr } // GetStatus get current task status -func (t *Task) GetStatus() string { +func (t *Task) GetStatus(context.Context) string { var str string for _, v := range t.Errlist { + if v == nil { + continue + } str += v.t.String() + ":" + v.errinfo + "
" } return str } // Run run all tasks -func (t *Task) Run() error { - err := t.DoFunc() +func (t *Task) Run(ctx context.Context) error { + err := t.DoFunc(ctx) if err != nil { - if t.ErrLimit > 0 && t.ErrLimit > len(t.Errlist) { - t.Errlist = append(t.Errlist, &taskerr{t: t.Next, errinfo: err.Error()}) - } + index := t.errCnt % t.ErrLimit + t.Errlist[index] = &taskerr{t: t.Next, errinfo: err.Error()} + t.errCnt++ } return err } // SetNext set next time for this task -func (t *Task) SetNext(now time.Time) { +func (t *Task) SetNext(ctx context.Context, now time.Time) { t.Next = t.Spec.Next(now) } // GetNext get the next call time of this task -func (t *Task) GetNext() time.Time { +func (t *Task) GetNext(context.Context) time.Time { return t.Next } // SetPrev set prev time of this task -func (t *Task) SetPrev(now time.Time) { +func (t *Task) SetPrev(ctx context.Context, now time.Time) { t.Prev = now } // GetPrev get prev time of this task -func (t *Task) GetPrev() time.Time { +func (t *Task) GetPrev(context.Context) time.Time { return t.Prev } +// GetTimeout get timeout duration of this task +func (t *Task) GetTimeout(context.Context) time.Duration { + return t.Timeout +} + +// Option interface +type Option interface { + apply(*Task) +} + +// optionFunc return a function to set task element +type optionFunc func(*Task) + +// apply option to task +func (f optionFunc) apply(t *Task) { + f(t) +} + +// TimeoutOption return an option to set timeout duration for task +func TimeoutOption(timeout time.Duration) Option { + return optionFunc(func(t *Task) { + t.Timeout = timeout + }) +} + // six columns mean: // second:0-59 // minute:0-59 @@ -180,11 +237,16 @@ func (t *Task) GetPrev() time.Time { // week:0-6(0 means Sunday) // SetCron some signals: -// *: any time -// ,:  separate signal -//   -:duration -// /n : do as n times of time duration -///////////////////////////////////////////////////////// +// +// *: any time +// ,:  separate signal +// +//    -:duration +// +// /n : do as n times of time duration +// +// /////////////////////////////////////////////////////// +// // 0/30 * * * * * every 30s // 0 43 21 * * * 21:43 // 0 15 05 * * *    05:15 @@ -293,7 +355,6 @@ func (t *Task) parseSpec(spec string) *Schedule { // Next set schedule to next time func (s *Schedule) Next(t time.Time) time.Time { - // Start at the earliest possible time (the upcoming second). t = t.Add(1*time.Second - time.Duration(t.Nanosecond())*time.Nanosecond) @@ -391,89 +452,231 @@ func dayMatches(s *Schedule, t time.Time) bool { // StartTask start all tasks func StartTask() { - taskLock.Lock() - defer taskLock.Unlock() - if isstart { - //If already started, no need to start another goroutine. + globalTaskManager.StartTask() +} + +// StopTask stop all tasks +func StopTask() { + globalTaskManager.StopTask() +} + +// AddTask add task with name +func AddTask(taskName string, t Tasker) { + globalTaskManager.AddTask(taskName, t) +} + +// DeleteTask delete task with name +func DeleteTask(taskName string) { + globalTaskManager.DeleteTask(taskName) +} + +// ClearTask clear all tasks +func ClearTask() { + globalTaskManager.ClearTask() +} + +// GetAllTasks get all tasks +func GetAllTasks() []Tasker { + return globalTaskManager.GetAllTasks() +} + +// GracefulShutdown wait all task done +func GracefulShutdown() <-chan struct{} { + return globalTaskManager.GracefulShutdown() +} + +// StartTask start all tasks +func (m *taskManager) StartTask() { + m.taskLock.Lock() + defer m.taskLock.Unlock() + if m.started { + // If already started, no need to start another goroutine. return } - isstart = true - go run() + m.started = true + + registerCommands() + go m.run() } -func run() { +func (m *taskManager) run() { now := time.Now().Local() - for _, t := range AdminTaskList { - t.SetNext(now) - } + // first run the tasks, so set all tasks next run time. + m.setTasksStartTime(now) for { - sortList := NewMapSorter(AdminTaskList) + // we only use RLock here because NewMapSorter copy the reference, do not change any thing + // here, we sort all task and get first task running time (effective). + m.taskLock.RLock() + sortList := NewMapSorter(m.adminTaskList) + m.taskLock.RUnlock() sortList.Sort() var effective time.Time - if len(AdminTaskList) == 0 || sortList.Vals[0].GetNext().IsZero() { + if len(m.adminTaskList) == 0 || sortList.Vals[0].GetNext(context.Background()).IsZero() { // If there are no entries yet, just sleep - it still handles new entries // and stop requests. effective = now.AddDate(10, 0, 0) } else { - effective = sortList.Vals[0].GetNext() + effective = sortList.Vals[0].GetNext(context.Background()) } + select { - case now = <-time.After(effective.Sub(now)): - // Run every entry whose next time was this effective time. - for _, e := range sortList.Vals { - if e.GetNext() != effective { - break - } - go e.Run() - e.SetPrev(e.GetNext()) - e.SetNext(effective) - } + case now = <-time.After(effective.Sub(now)): // wait for effective time + m.runNextTasks(sortList, effective) continue - case <-changed: + case <-m.changed: // tasks have been changed, set all tasks run again now now = time.Now().Local() - for _, t := range AdminTaskList { - t.SetNext(now) - } + m.setTasksStartTime(now) continue - case <-stop: + case <-m.stop: // manager is stopped, and mark manager is stopped + m.markManagerStop() return } } } -// StopTask stop all tasks -func StopTask() { - taskLock.Lock() - defer taskLock.Unlock() - if isstart { - isstart = false - stop <- true +// setTasksStartTime is set all tasks next running time +func (m *taskManager) setTasksStartTime(now time.Time) { + m.taskLock.Lock() + for _, task := range m.adminTaskList { + task.SetNext(context.Background(), now) + } + m.taskLock.Unlock() +} + +// markManagerStop it sets manager to be stopped +func (m *taskManager) markManagerStop() { + m.taskLock.Lock() + if m.started { + m.started = false } + m.taskLock.Unlock() +} + +// runNextTasks it runs next task which next run time is equal to effective +func (m *taskManager) runNextTasks(sortList *MapSorter, effective time.Time) { + // Run every entry whose next time was this effective time. + i := 0 + for _, e := range sortList.Vals { + i++ + if e.GetNext(context.Background()) != effective { + break + } + + // check if timeout is on, if yes passing the timeout context + ctx := context.Background() + m.wait.Add(1) + if duration := e.GetTimeout(ctx); duration != 0 { + go func(e Tasker) { + defer m.wait.Done() + ctx, cancelFunc := context.WithTimeout(ctx, duration) + defer cancelFunc() + err := e.Run(ctx) + if err != nil { + log.Printf("tasker.run err: %s\n", err.Error()) + } + }(e) + } else { + go func(e Tasker) { + defer m.wait.Done() + err := e.Run(ctx) + if err != nil { + log.Printf("tasker.run err: %s\n", err.Error()) + } + }(e) + } + + e.SetPrev(context.Background(), e.GetNext(context.Background())) + e.SetNext(context.Background(), effective) + } +} +// StopTask stop all tasks +func (m *taskManager) StopTask() { + go func() { + m.stop <- true + }() +} + +// GracefulShutdown wait all task done +func (m *taskManager) GracefulShutdown() <-chan struct{} { + done := make(chan struct{}) + go func() { + m.stop <- true + m.wait.Wait() + close(done) + }() + return done } // AddTask add task with name -func AddTask(taskname string, t Tasker) { - taskLock.Lock() - defer taskLock.Unlock() - t.SetNext(time.Now().Local()) - AdminTaskList[taskname] = t - if isstart { - changed <- true +func (m *taskManager) AddTask(taskname string, t Tasker) { + isChanged := false + m.taskLock.Lock() + t.SetNext(nil, time.Now().Local()) + m.adminTaskList[taskname] = t + if m.started { + isChanged = true + } + m.taskLock.Unlock() + + if isChanged { + go func() { + m.changed <- true + }() } } // DeleteTask delete task with name -func DeleteTask(taskname string) { - taskLock.Lock() - defer taskLock.Unlock() - delete(AdminTaskList, taskname) - if isstart { - changed <- true +func (m *taskManager) DeleteTask(taskname string) { + isChanged := false + + m.taskLock.Lock() + delete(m.adminTaskList, taskname) + if m.started { + isChanged = true + } + m.taskLock.Unlock() + + if isChanged { + go func() { + m.changed <- true + }() } } +// ClearTask clear all tasks +func (m *taskManager) ClearTask() { + isChanged := false + + m.taskLock.Lock() + m.adminTaskList = make(map[string]Tasker) + if m.started { + isChanged = true + } + m.taskLock.Unlock() + + if isChanged { + go func() { + m.changed <- true + }() + } +} + +// GetAllTasks get all tasks +func (m *taskManager) GetAllTasks() []Tasker { + m.taskLock.RLock() + + l := make([]Tasker, 0, len(m.adminTaskList)) + + for _, t := range m.adminTaskList { + l = append(l, t) + } + m.taskLock.RUnlock() + + return l +} + // MapSorter sort map for tasker type MapSorter struct { Keys []string @@ -499,15 +702,17 @@ func (ms *MapSorter) Sort() { } func (ms *MapSorter) Len() int { return len(ms.Keys) } + func (ms *MapSorter) Less(i, j int) bool { - if ms.Vals[i].GetNext().IsZero() { + if ms.Vals[i].GetNext(context.Background()).IsZero() { return false } - if ms.Vals[j].GetNext().IsZero() { + if ms.Vals[j].GetNext(context.Background()).IsZero() { return true } - return ms.Vals[i].GetNext().Before(ms.Vals[j].GetNext()) + return ms.Vals[i].GetNext(context.Background()).Before(ms.Vals[j].GetNext(context.Background())) } + func (ms *MapSorter) Swap(i, j int) { ms.Vals[i], ms.Vals[j] = ms.Vals[j], ms.Vals[i] ms.Keys[i], ms.Keys[j] = ms.Keys[j], ms.Keys[i] @@ -524,9 +729,9 @@ func getField(field string, r bounds) uint64 { } // getRange returns the bits indicated by the given expression: -// number | number "-" number [ "/" number ] +// +// number | number "-" number [ "/" number ] func getRange(expr string, r bounds) uint64 { - var ( start, end, step uint rangeAndStep = strings.Split(expr, "/") @@ -623,7 +828,5 @@ func all(r bounds) uint64 { } func init() { - AdminTaskList = make(map[string]Tasker) - stop = make(chan bool) - changed = make(chan bool) + globalTaskManager = newTaskManager() } diff --git a/task/task_test.go b/task/task_test.go new file mode 100644 index 0000000000..98a3a05912 --- /dev/null +++ b/task/task_test.go @@ -0,0 +1,229 @@ +// Copyright 2014 beego Author. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package task + +import ( + "context" + "errors" + "fmt" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestParse(t *testing.T) { + m := newTaskManager() + defer m.ClearTask() + tk := NewTask("taska", "0/30 * * * * *", func(ctx context.Context) error { + fmt.Println("hello world") + return nil + }) + err := tk.Run(nil) + if err != nil { + t.Fatal(err) + } + assert.Equal(t, "0/30 * * * * *", tk.GetSpec(context.Background())) + m.AddTask("taska", tk) + m.StartTask() + time.Sleep(3 * time.Second) + assert.True(t, len(tk.GetStatus(context.Background())) == 0) + m.StopTask() +} + +func TestModifyTaskListAfterRunning(t *testing.T) { + m := newTaskManager() + defer m.ClearTask() + tk := NewTask("taskb", "0/30 * * * * *", func(ctx context.Context) error { + fmt.Println("hello world") + return nil + }) + err := tk.Run(nil) + if err != nil { + t.Fatal(err) + } + m.AddTask("taskb", tk) + m.StartTask() + go func() { + m.DeleteTask("taskb") + }() + go func() { + m.AddTask("taskb1", tk) + }() + + time.Sleep(3 * time.Second) + m.StopTask() +} + +func TestSpec(t *testing.T) { + m := newTaskManager() + defer m.ClearTask() + wg := &sync.WaitGroup{} + wg.Add(2) + tk1 := NewTask("tk1", "0 12 * * * *", func(ctx context.Context) error { fmt.Println("tk1"); return nil }) + tk2 := NewTask("tk2", "0,10,20 * * * * *", func(ctx context.Context) error { fmt.Println("tk2"); wg.Done(); return nil }) + tk3 := NewTask("tk3", "0 10 * * * *", func(ctx context.Context) error { fmt.Println("tk3"); wg.Done(); return nil }) + + m.AddTask("tk1", tk1) + m.AddTask("tk2", tk2) + m.AddTask("tk3", tk3) + m.StartTask() + defer m.StopTask() + + select { + case <-time.After(200 * time.Second): + t.FailNow() + case <-wait(wg): + } +} + +func TestTimeout(t *testing.T) { + m := newTaskManager() + defer m.ClearTask() + wg := &sync.WaitGroup{} + wg.Add(2) + once1, once2 := sync.Once{}, sync.Once{} + + tk1 := NewTask("tk1", "0/10 * * ? * *", + func(ctx context.Context) error { + time.Sleep(4 * time.Second) + select { + case <-ctx.Done(): + once1.Do(func() { + fmt.Println("tk1 done") + wg.Done() + }) + return errors.New("timeout") + default: + } + return nil + }, TimeoutOption(3*time.Second), + ) + + tk2 := NewTask("tk2", "0/11 * * ? * *", + func(ctx context.Context) error { + time.Sleep(4 * time.Second) + select { + case <-ctx.Done(): + return errors.New("timeout") + default: + once2.Do(func() { + fmt.Println("tk2 done") + wg.Done() + }) + } + return nil + }, + ) + + m.AddTask("tk1", tk1) + m.AddTask("tk2", tk2) + m.StartTask() + defer m.StopTask() + + select { + case <-time.After(19 * time.Second): + t.Error("TestTimeout failed") + case <-wait(wg): + } +} + +func TestTask_Run(t *testing.T) { + cnt := -1 + task := func(ctx context.Context) error { + cnt++ + fmt.Printf("Hello, world! %d \n", cnt) + return fmt.Errorf("Hello, world! %d", cnt) + } + tk := NewTask("taska", "0/30 * * * * *", task) + for i := 0; i < 200; i++ { + e := tk.Run(nil) + assert.NotNil(t, e) + } + + l := tk.Errlist + assert.Equal(t, 100, len(l)) + assert.Equal(t, "Hello, world! 100", l[0].errinfo) + assert.Equal(t, "Hello, world! 101", l[1].errinfo) +} + +func TestCrudTask(t *testing.T) { + m := newTaskManager() + m.AddTask("my-task1", NewTask("my-task1", "0/30 * * * * *", func(ctx context.Context) error { + return nil + })) + + m.AddTask("my-task2", NewTask("my-task2", "0/30 * * * * *", func(ctx context.Context) error { + return nil + })) + + m.DeleteTask("my-task1") + assert.Equal(t, 1, len(m.adminTaskList)) + + m.ClearTask() + assert.Equal(t, 0, len(m.adminTaskList)) +} + +func TestGracefulShutdown(t *testing.T) { + m := newTaskManager() + defer m.ClearTask() + waitDone := atomic.Value{} + waitDone.Store(false) + tk := NewTask("everySecond", "* * * * * *", func(ctx context.Context) error { + fmt.Println("hello world") + time.Sleep(2 * time.Second) + waitDone.Store(true) + return nil + }) + m.AddTask("taska", tk) + m.StartTask() + time.Sleep(1 * time.Second) + shutdown := m.GracefulShutdown() + assert.False(t, waitDone.Load().(bool)) + <-shutdown + assert.True(t, waitDone.Load().(bool)) +} + +func wait(wg *sync.WaitGroup) chan bool { + ch := make(chan bool) + go func() { + wg.Wait() + ch <- true + }() + return ch +} + +func TestGetAllTasks(t *testing.T) { + m := newTaskManager() + defer m.ClearTask() + + tk := NewTask("task1", "0/30 * * * * *", func(ctx context.Context) error { + return nil + }) + + tk2 := NewTask("task2", "0/40 * * * * *", func(ctx context.Context) error { + return nil + }) + + m.AddTask("task1", tk) + m.AddTask("task2", tk2) + + tasks := m.GetAllTasks() + total := len(tasks) + + assert.Equal(t, 2, total) +} diff --git a/test/Makefile b/test/Makefile new file mode 100644 index 0000000000..7483cf0588 --- /dev/null +++ b/test/Makefile @@ -0,0 +1,2 @@ +build_view: + $(GOPATH)/bin/go-bindata-assetfs -pkg testdata views/... diff --git a/testdata/bindata.go b/test/bindata.go similarity index 91% rename from testdata/bindata.go rename to test/bindata.go index beade103db..414625e36a 100644 --- a/testdata/bindata.go +++ b/test/bindata.go @@ -5,19 +5,20 @@ // views/index.tpl // DO NOT EDIT! -package testdata +package test import ( "bytes" "compress/gzip" "fmt" - "github.com/elazarl/go-bindata-assetfs" "io" "io/ioutil" "os" "path/filepath" "strings" "time" + + assetfs "github.com/elazarl/go-bindata-assetfs" ) func bindataRead(data []byte, name string) ([]byte, error) { @@ -55,18 +56,23 @@ type bindataFileInfo struct { func (fi bindataFileInfo) Name() string { return fi.name } + func (fi bindataFileInfo) Size() int64 { return fi.size } + func (fi bindataFileInfo) Mode() os.FileMode { return fi.mode } + func (fi bindataFileInfo) ModTime() time.Time { return fi.modTime } + func (fi bindataFileInfo) IsDir() bool { return false } + func (fi bindataFileInfo) Sys() interface{} { return nil } @@ -86,7 +92,7 @@ func viewsBlocksBlockTpl() (*asset, error) { return nil, err } - info := bindataFileInfo{name: "views/blocks/block.tpl", size: 50, mode: os.FileMode(436), modTime: time.Unix(1541431067, 0)} + info := bindataFileInfo{name: "views/blocks/block.tpl", size: 50, mode: os.FileMode(0o664), modTime: time.Unix(1541431067, 0)} a := &asset{bytes: bytes, info: info} return a, nil } @@ -106,7 +112,7 @@ func viewsHeaderTpl() (*asset, error) { return nil, err } - info := bindataFileInfo{name: "views/header.tpl", size: 52, mode: os.FileMode(436), modTime: time.Unix(1541431067, 0)} + info := bindataFileInfo{name: "views/header.tpl", size: 52, mode: os.FileMode(0o664), modTime: time.Unix(1541431067, 0)} a := &asset{bytes: bytes, info: info} return a, nil } @@ -126,7 +132,7 @@ func viewsIndexTpl() (*asset, error) { return nil, err } - info := bindataFileInfo{name: "views/index.tpl", size: 255, mode: os.FileMode(436), modTime: time.Unix(1541434906, 0)} + info := bindataFileInfo{name: "views/index.tpl", size: 255, mode: os.FileMode(0o664), modTime: time.Unix(1541434906, 0)} a := &asset{bytes: bytes, info: info} return a, nil } @@ -192,11 +198,13 @@ var _bindata = map[string]func() (*asset, error){ // directory embedded in the file by go-bindata. // For example if you run go-bindata on data/... and data contains the // following hierarchy: -// data/ -// foo.txt -// img/ -// a.png -// b.png +// +// data/ +// foo.txt +// img/ +// a.png +// b.png +// // then AssetDir("data") would return []string{"foo.txt", "img"} // AssetDir("data/img") would return []string{"a.png", "b.png"} // AssetDir("foo.txt") and AssetDir("notexist") would return an error @@ -229,12 +237,12 @@ type bintree struct { } var _bintree = &bintree{nil, map[string]*bintree{ - "views": &bintree{nil, map[string]*bintree{ - "blocks": &bintree{nil, map[string]*bintree{ - "block.tpl": &bintree{viewsBlocksBlockTpl, map[string]*bintree{}}, + "views": {nil, map[string]*bintree{ + "blocks": {nil, map[string]*bintree{ + "block.tpl": {viewsBlocksBlockTpl, map[string]*bintree{}}, }}, - "header.tpl": &bintree{viewsHeaderTpl, map[string]*bintree{}}, - "index.tpl": &bintree{viewsIndexTpl, map[string]*bintree{}}, + "header.tpl": {viewsHeaderTpl, map[string]*bintree{}}, + "index.tpl": {viewsIndexTpl, map[string]*bintree{}}, }}, }} @@ -248,7 +256,7 @@ func RestoreAsset(dir, name string) error { if err != nil { return err } - err = os.MkdirAll(_filePath(dir, filepath.Dir(name)), os.FileMode(0755)) + err = os.MkdirAll(_filePath(dir, filepath.Dir(name)), os.FileMode(0o755)) if err != nil { return err } @@ -286,9 +294,7 @@ func _filePath(dir, name string) string { } func assetFS() *assetfs.AssetFS { - assetInfo := func(path string) (os.FileInfo, error) { - return os.Stat(path) - } + assetInfo := os.Stat for k := range _bintree.Children { return &assetfs.AssetFS{Asset: Asset, AssetDir: AssetDir, AssetInfo: assetInfo, Prefix: k} } diff --git a/testdata/views/blocks/block.tpl b/test/views/blocks/block.tpl similarity index 84% rename from testdata/views/blocks/block.tpl rename to test/views/blocks/block.tpl index 2a9c57fce3..bd4cf96000 100644 --- a/testdata/views/blocks/block.tpl +++ b/test/views/blocks/block.tpl @@ -1,3 +1,3 @@ {{define "block"}}

Hello, blocks!

-{{end}} \ No newline at end of file +{{end}} diff --git a/testdata/views/header.tpl b/test/views/header.tpl similarity index 84% rename from testdata/views/header.tpl rename to test/views/header.tpl index 041fa403d7..0d36989e15 100644 --- a/testdata/views/header.tpl +++ b/test/views/header.tpl @@ -1,3 +1,3 @@ {{define "header"}}

Hello, astaxie!

-{{end}} \ No newline at end of file +{{end}} diff --git a/testdata/views/index.tpl b/test/views/index.tpl similarity index 100% rename from testdata/views/index.tpl rename to test/views/index.tpl diff --git a/testdata/Makefile b/testdata/Makefile deleted file mode 100644 index e80e8238a5..0000000000 --- a/testdata/Makefile +++ /dev/null @@ -1,2 +0,0 @@ -build_view: - $(GOPATH)/bin/go-bindata-assetfs -pkg testdata views/... \ No newline at end of file diff --git a/toolbox/task_test.go b/toolbox/task_test.go deleted file mode 100644 index 596bc9c5b0..0000000000 --- a/toolbox/task_test.go +++ /dev/null @@ -1,63 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package toolbox - -import ( - "fmt" - "sync" - "testing" - "time" -) - -func TestParse(t *testing.T) { - tk := NewTask("taska", "0/30 * * * * *", func() error { fmt.Println("hello world"); return nil }) - err := tk.Run() - if err != nil { - t.Fatal(err) - } - AddTask("taska", tk) - StartTask() - time.Sleep(6 * time.Second) - StopTask() -} - -func TestSpec(t *testing.T) { - wg := &sync.WaitGroup{} - wg.Add(2) - tk1 := NewTask("tk1", "0 12 * * * *", func() error { fmt.Println("tk1"); return nil }) - tk2 := NewTask("tk2", "0,10,20 * * * * *", func() error { fmt.Println("tk2"); wg.Done(); return nil }) - tk3 := NewTask("tk3", "0 10 * * * *", func() error { fmt.Println("tk3"); wg.Done(); return nil }) - - AddTask("tk1", tk1) - AddTask("tk2", tk2) - AddTask("tk3", tk3) - StartTask() - defer StopTask() - - select { - case <-time.After(200 * time.Second): - t.FailNow() - case <-wait(wg): - } -} - -func wait(wg *sync.WaitGroup) chan bool { - ch := make(chan bool) - go func() { - wg.Wait() - ch <- true - }() - return ch -} diff --git a/tree_test.go b/tree_test.go deleted file mode 100644 index d412a34812..0000000000 --- a/tree_test.go +++ /dev/null @@ -1,306 +0,0 @@ -// Copyright 2014 beego Author. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package beego - -import ( - "strings" - "testing" - - "github.com/astaxie/beego/context" -) - -type testinfo struct { - url string - requesturl string - params map[string]string -} - -var routers []testinfo - -func init() { - routers = make([]testinfo, 0) - routers = append(routers, testinfo{"/topic/?:auth:int", "/topic", nil}) - routers = append(routers, testinfo{"/topic/?:auth:int", "/topic/123", map[string]string{":auth": "123"}}) - routers = append(routers, testinfo{"/topic/:id/?:auth", "/topic/1", map[string]string{":id": "1"}}) - routers = append(routers, testinfo{"/topic/:id/?:auth", "/topic/1/2", map[string]string{":id": "1", ":auth": "2"}}) - routers = append(routers, testinfo{"/topic/:id/?:auth:int", "/topic/1", map[string]string{":id": "1"}}) - routers = append(routers, testinfo{"/topic/:id/?:auth:int", "/topic/1/123", map[string]string{":id": "1", ":auth": "123"}}) - routers = append(routers, testinfo{"/:id", "/123", map[string]string{":id": "123"}}) - routers = append(routers, testinfo{"/hello/?:id", "/hello", map[string]string{":id": ""}}) - routers = append(routers, testinfo{"/", "/", nil}) - routers = append(routers, testinfo{"/customer/login", "/customer/login", nil}) - routers = append(routers, testinfo{"/customer/login", "/customer/login.json", map[string]string{":ext": "json"}}) - routers = append(routers, testinfo{"/*", "/http://customer/123/", map[string]string{":splat": "http://customer/123/"}}) - routers = append(routers, testinfo{"/*", "/customer/2009/12/11", map[string]string{":splat": "customer/2009/12/11"}}) - routers = append(routers, testinfo{"/aa/*/bb", "/aa/2009/bb", map[string]string{":splat": "2009"}}) - routers = append(routers, testinfo{"/cc/*/dd", "/cc/2009/11/dd", map[string]string{":splat": "2009/11"}}) - routers = append(routers, testinfo{"/cc/:id/*", "/cc/2009/11/dd", map[string]string{":id": "2009", ":splat": "11/dd"}}) - routers = append(routers, testinfo{"/ee/:year/*/ff", "/ee/2009/11/ff", map[string]string{":year": "2009", ":splat": "11"}}) - routers = append(routers, testinfo{"/thumbnail/:size/uploads/*", - "/thumbnail/100x100/uploads/items/2014/04/20/dPRCdChkUd651t1Hvs18.jpg", - map[string]string{":size": "100x100", ":splat": "items/2014/04/20/dPRCdChkUd651t1Hvs18.jpg"}}) - routers = append(routers, testinfo{"/*.*", "/nice/api.json", map[string]string{":path": "nice/api", ":ext": "json"}}) - routers = append(routers, testinfo{"/:name/*.*", "/nice/api.json", map[string]string{":name": "nice", ":path": "api", ":ext": "json"}}) - routers = append(routers, testinfo{"/:name/test/*.*", "/nice/test/api.json", map[string]string{":name": "nice", ":path": "api", ":ext": "json"}}) - routers = append(routers, testinfo{"/dl/:width:int/:height:int/*.*", - "/dl/48/48/05ac66d9bda00a3acf948c43e306fc9a.jpg", - map[string]string{":width": "48", ":height": "48", ":ext": "jpg", ":path": "05ac66d9bda00a3acf948c43e306fc9a"}}) - routers = append(routers, testinfo{"/v1/shop/:id:int", "/v1/shop/123", map[string]string{":id": "123"}}) - routers = append(routers, testinfo{"/v1/shop/:id\\((a|b|c)\\)", "/v1/shop/123(a)", map[string]string{":id": "123"}}) - routers = append(routers, testinfo{"/v1/shop/:id\\((a|b|c)\\)", "/v1/shop/123(b)", map[string]string{":id": "123"}}) - routers = append(routers, testinfo{"/v1/shop/:id\\((a|b|c)\\)", "/v1/shop/123(c)", map[string]string{":id": "123"}}) - routers = append(routers, testinfo{"/:year:int/:month:int/:id/:endid", "/1111/111/aaa/aaa", map[string]string{":year": "1111", ":month": "111", ":id": "aaa", ":endid": "aaa"}}) - routers = append(routers, testinfo{"/v1/shop/:id/:name", "/v1/shop/123/nike", map[string]string{":id": "123", ":name": "nike"}}) - routers = append(routers, testinfo{"/v1/shop/:id/account", "/v1/shop/123/account", map[string]string{":id": "123"}}) - routers = append(routers, testinfo{"/v1/shop/:name:string", "/v1/shop/nike", map[string]string{":name": "nike"}}) - routers = append(routers, testinfo{"/v1/shop/:id([0-9]+)", "/v1/shop//123", map[string]string{":id": "123"}}) - routers = append(routers, testinfo{"/v1/shop/:id([0-9]+)_:name", "/v1/shop/123_nike", map[string]string{":id": "123", ":name": "nike"}}) - routers = append(routers, testinfo{"/v1/shop/:id(.+)_cms.html", "/v1/shop/123_cms.html", map[string]string{":id": "123"}}) - routers = append(routers, testinfo{"/v1/shop/cms_:id(.+)_:page(.+).html", "/v1/shop/cms_123_1.html", map[string]string{":id": "123", ":page": "1"}}) - routers = append(routers, testinfo{"/v1/:v/cms/aaa_:id(.+)_:page(.+).html", "/v1/2/cms/aaa_123_1.html", map[string]string{":v": "2", ":id": "123", ":page": "1"}}) - routers = append(routers, testinfo{"/v1/:v/cms_:id(.+)_:page(.+).html", "/v1/2/cms_123_1.html", map[string]string{":v": "2", ":id": "123", ":page": "1"}}) - routers = append(routers, testinfo{"/v1/:v(.+)_cms/ttt_:id(.+)_:page(.+).html", "/v1/2_cms/ttt_123_1.html", map[string]string{":v": "2", ":id": "123", ":page": "1"}}) - routers = append(routers, testinfo{"/api/projects/:pid/members/?:mid", "/api/projects/1/members", map[string]string{":pid": "1"}}) - routers = append(routers, testinfo{"/api/projects/:pid/members/?:mid", "/api/projects/1/members/2", map[string]string{":pid": "1", ":mid": "2"}}) -} - -func TestTreeRouters(t *testing.T) { - for _, r := range routers { - tr := NewTree() - tr.AddRouter(r.url, "astaxie") - ctx := context.NewContext() - obj := tr.Match(r.requesturl, ctx) - if obj == nil || obj.(string) != "astaxie" { - t.Fatal(r.url+" can't get obj, Expect ", r.requesturl) - } - if r.params != nil { - for k, v := range r.params { - if vv := ctx.Input.Param(k); vv != v { - t.Fatal("The Rule: " + r.url + "\nThe RequestURL:" + r.requesturl + "\nThe Key is " + k + ", The Value should be: " + v + ", but get: " + vv) - } else if vv == "" && v != "" { - t.Fatal(r.url + " " + r.requesturl + " get param empty:" + k) - } - } - } - } -} - -func TestStaticPath(t *testing.T) { - tr := NewTree() - tr.AddRouter("/topic/:id", "wildcard") - tr.AddRouter("/topic", "static") - ctx := context.NewContext() - obj := tr.Match("/topic", ctx) - if obj == nil || obj.(string) != "static" { - t.Fatal("/topic is a static route") - } - obj = tr.Match("/topic/1", ctx) - if obj == nil || obj.(string) != "wildcard" { - t.Fatal("/topic/1 is a wildcard route") - } -} - -func TestAddTree(t *testing.T) { - tr := NewTree() - tr.AddRouter("/shop/:id/account", "astaxie") - tr.AddRouter("/shop/:sd/ttt_:id(.+)_:page(.+).html", "astaxie") - t1 := NewTree() - t1.AddTree("/v1/zl", tr) - ctx := context.NewContext() - obj := t1.Match("/v1/zl/shop/123/account", ctx) - if obj == nil || obj.(string) != "astaxie" { - t.Fatal("/v1/zl/shop/:id/account can't get obj ") - } - if ctx.Input.ParamsLen() == 0 { - t.Fatal("get param error") - } - if ctx.Input.Param(":id") != "123" { - t.Fatal("get :id param error") - } - ctx.Input.Reset(ctx) - obj = t1.Match("/v1/zl/shop/123/ttt_1_12.html", ctx) - if obj == nil || obj.(string) != "astaxie" { - t.Fatal("/v1/zl//shop/:sd/ttt_:id(.+)_:page(.+).html can't get obj ") - } - if ctx.Input.ParamsLen() == 0 { - t.Fatal("get param error") - } - if ctx.Input.Param(":sd") != "123" || ctx.Input.Param(":id") != "1" || ctx.Input.Param(":page") != "12" { - t.Fatal("get :sd :id :page param error") - } - - t2 := NewTree() - t2.AddTree("/v1/:shopid", tr) - ctx.Input.Reset(ctx) - obj = t2.Match("/v1/zl/shop/123/account", ctx) - if obj == nil || obj.(string) != "astaxie" { - t.Fatal("/v1/:shopid/shop/:id/account can't get obj ") - } - if ctx.Input.ParamsLen() == 0 { - t.Fatal("get param error") - } - if ctx.Input.Param(":id") != "123" || ctx.Input.Param(":shopid") != "zl" { - t.Fatal("get :id :shopid param error") - } - ctx.Input.Reset(ctx) - obj = t2.Match("/v1/zl/shop/123/ttt_1_12.html", ctx) - if obj == nil || obj.(string) != "astaxie" { - t.Fatal("/v1/:shopid/shop/:sd/ttt_:id(.+)_:page(.+).html can't get obj ") - } - if ctx.Input.ParamsLen() == 0 { - t.Fatal("get :shopid param error") - } - if ctx.Input.Param(":sd") != "123" || ctx.Input.Param(":id") != "1" || ctx.Input.Param(":page") != "12" || ctx.Input.Param(":shopid") != "zl" { - t.Fatal("get :sd :id :page :shopid param error") - } -} - -func TestAddTree2(t *testing.T) { - tr := NewTree() - tr.AddRouter("/shop/:id/account", "astaxie") - tr.AddRouter("/shop/:sd/ttt_:id(.+)_:page(.+).html", "astaxie") - t3 := NewTree() - t3.AddTree("/:version(v1|v2)/:prefix", tr) - ctx := context.NewContext() - obj := t3.Match("/v1/zl/shop/123/account", ctx) - if obj == nil || obj.(string) != "astaxie" { - t.Fatal("/:version(v1|v2)/:prefix/shop/:id/account can't get obj ") - } - if ctx.Input.ParamsLen() == 0 { - t.Fatal("get param error") - } - if ctx.Input.Param(":id") != "123" || ctx.Input.Param(":prefix") != "zl" || ctx.Input.Param(":version") != "v1" { - t.Fatal("get :id :prefix :version param error") - } -} - -func TestAddTree3(t *testing.T) { - tr := NewTree() - tr.AddRouter("/create", "astaxie") - tr.AddRouter("/shop/:sd/account", "astaxie") - t3 := NewTree() - t3.AddTree("/table/:num", tr) - ctx := context.NewContext() - obj := t3.Match("/table/123/shop/123/account", ctx) - if obj == nil || obj.(string) != "astaxie" { - t.Fatal("/table/:num/shop/:sd/account can't get obj ") - } - if ctx.Input.ParamsLen() == 0 { - t.Fatal("get param error") - } - if ctx.Input.Param(":num") != "123" || ctx.Input.Param(":sd") != "123" { - t.Fatal("get :num :sd param error") - } - ctx.Input.Reset(ctx) - obj = t3.Match("/table/123/create", ctx) - if obj == nil || obj.(string) != "astaxie" { - t.Fatal("/table/:num/create can't get obj ") - } -} - -func TestAddTree4(t *testing.T) { - tr := NewTree() - tr.AddRouter("/create", "astaxie") - tr.AddRouter("/shop/:sd/:account", "astaxie") - t4 := NewTree() - t4.AddTree("/:info:int/:num/:id", tr) - ctx := context.NewContext() - obj := t4.Match("/12/123/456/shop/123/account", ctx) - if obj == nil || obj.(string) != "astaxie" { - t.Fatal("/:info:int/:num/:id/shop/:sd/:account can't get obj ") - } - if ctx.Input.ParamsLen() == 0 { - t.Fatal("get param error") - } - if ctx.Input.Param(":info") != "12" || ctx.Input.Param(":num") != "123" || - ctx.Input.Param(":id") != "456" || ctx.Input.Param(":sd") != "123" || - ctx.Input.Param(":account") != "account" { - t.Fatal("get :info :num :id :sd :account param error") - } - ctx.Input.Reset(ctx) - obj = t4.Match("/12/123/456/create", ctx) - if obj == nil || obj.(string) != "astaxie" { - t.Fatal("/:info:int/:num/:id/create can't get obj ") - } -} - -// Test for issue #1595 -func TestAddTree5(t *testing.T) { - tr := NewTree() - tr.AddRouter("/v1/shop/:id", "shopdetail") - tr.AddRouter("/v1/shop/", "shophome") - ctx := context.NewContext() - obj := tr.Match("/v1/shop/", ctx) - if obj == nil || obj.(string) != "shophome" { - t.Fatal("url /v1/shop/ need match router /v1/shop/ ") - } -} - -func TestSplitPath(t *testing.T) { - a := splitPath("") - if len(a) != 0 { - t.Fatal("/ should retrun []") - } - a = splitPath("/") - if len(a) != 0 { - t.Fatal("/ should retrun []") - } - a = splitPath("/admin") - if len(a) != 1 || a[0] != "admin" { - t.Fatal("/admin should retrun [admin]") - } - a = splitPath("/admin/") - if len(a) != 1 || a[0] != "admin" { - t.Fatal("/admin/ should retrun [admin]") - } - a = splitPath("/admin/users") - if len(a) != 2 || a[0] != "admin" || a[1] != "users" { - t.Fatal("/admin should retrun [admin users]") - } - a = splitPath("/admin/:id:int") - if len(a) != 2 || a[0] != "admin" || a[1] != ":id:int" { - t.Fatal("/admin should retrun [admin :id:int]") - } -} - -func TestSplitSegment(t *testing.T) { - - items := map[string]struct { - isReg bool - params []string - regStr string - }{ - "admin": {false, nil, ""}, - "*": {true, []string{":splat"}, ""}, - "*.*": {true, []string{".", ":path", ":ext"}, ""}, - ":id": {true, []string{":id"}, ""}, - "?:id": {true, []string{":", ":id"}, ""}, - ":id:int": {true, []string{":id"}, "([0-9]+)"}, - ":name:string": {true, []string{":name"}, `([\w]+)`}, - ":id([0-9]+)": {true, []string{":id"}, `([0-9]+)`}, - ":id([0-9]+)_:name": {true, []string{":id", ":name"}, `([0-9]+)_(.+)`}, - ":id(.+)_cms.html": {true, []string{":id"}, `(.+)_cms.html`}, - "cms_:id(.+)_:page(.+).html": {true, []string{":id", ":page"}, `cms_(.+)_(.+).html`}, - `:app(a|b|c)`: {true, []string{":app"}, `(a|b|c)`}, - `:app\((a|b|c)\)`: {true, []string{":app"}, `(.+)\((a|b|c)\)`}, - } - - for pattern, v := range items { - b, w, r := splitSegment(pattern) - if b != v.isReg || r != v.regStr || strings.Join(w, ",") != strings.Join(v.params, ",") { - t.Fatalf("%s should return %t,%s,%q, got %t,%s,%q", pattern, v.isReg, v.params, v.regStr, b, w, r) - } - } -} diff --git a/utils/pagination/doc.go b/utils/pagination/doc.go deleted file mode 100644 index 9abc6d782c..0000000000 --- a/utils/pagination/doc.go +++ /dev/null @@ -1,58 +0,0 @@ -/* -Package pagination provides utilities to setup a paginator within the -context of a http request. - -Usage - -In your beego.Controller: - - package controllers - - import "github.com/astaxie/beego/utils/pagination" - - type PostsController struct { - beego.Controller - } - - func (this *PostsController) ListAllPosts() { - // sets this.Data["paginator"] with the current offset (from the url query param) - postsPerPage := 20 - paginator := pagination.SetPaginator(this.Ctx, postsPerPage, CountPosts()) - - // fetch the next 20 posts - this.Data["posts"] = ListPostsByOffsetAndLimit(paginator.Offset(), postsPerPage) - } - - -In your view templates: - - {{if .paginator.HasPages}} - - {{end}} - -See also - -http://beego.me/docs/mvc/view/page.md - -*/ -package pagination diff --git a/utils/utils.go b/utils/utils.go deleted file mode 100644 index ed88578734..0000000000 --- a/utils/utils.go +++ /dev/null @@ -1,30 +0,0 @@ -package utils - -import ( - "os" - "path/filepath" - "runtime" - "strings" -) - -// GetGOPATHs returns all paths in GOPATH variable. -func GetGOPATHs() []string { - gopath := os.Getenv("GOPATH") - if gopath == "" && strings.Compare(runtime.Version(), "go1.8") >= 0 { - gopath = defaultGOPATH() - } - return filepath.SplitList(gopath) -} - -func defaultGOPATH() string { - env := "HOME" - if runtime.GOOS == "windows" { - env = "USERPROFILE" - } else if runtime.GOOS == "plan9" { - env = "home" - } - if home := os.Getenv(env); home != "" { - return filepath.Join(home, "go") - } - return "" -} diff --git a/vendor/golang.org/x/crypto/LICENSE b/vendor/golang.org/x/crypto/LICENSE deleted file mode 100644 index 6a66aea5ea..0000000000 --- a/vendor/golang.org/x/crypto/LICENSE +++ /dev/null @@ -1,27 +0,0 @@ -Copyright (c) 2009 The Go Authors. All rights reserved. - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are -met: - - * Redistributions of source code must retain the above copyright -notice, this list of conditions and the following disclaimer. - * Redistributions in binary form must reproduce the above -copyright notice, this list of conditions and the following disclaimer -in the documentation and/or other materials provided with the -distribution. - * Neither the name of Google Inc. nor the names of its -contributors may be used to endorse or promote products derived from -this software without specific prior written permission. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS -"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT -OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT -LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, -DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY -THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/vendor/golang.org/x/crypto/PATENTS b/vendor/golang.org/x/crypto/PATENTS deleted file mode 100644 index 733099041f..0000000000 --- a/vendor/golang.org/x/crypto/PATENTS +++ /dev/null @@ -1,22 +0,0 @@ -Additional IP Rights Grant (Patents) - -"This implementation" means the copyrightable works distributed by -Google as part of the Go project. - -Google hereby grants to You a perpetual, worldwide, non-exclusive, -no-charge, royalty-free, irrevocable (except as stated in this section) -patent license to make, have made, use, offer to sell, sell, import, -transfer and otherwise run, modify and propagate the contents of this -implementation of Go, where such license applies only to those patent -claims, both currently owned or controlled by Google and acquired in -the future, licensable by Google that are necessarily infringed by this -implementation of Go. This grant does not include claims that would be -infringed only as a consequence of further modification of this -implementation. If you or your agent or exclusive licensee institute or -order or agree to the institution of patent litigation against any -entity (including a cross-claim or counterclaim in a lawsuit) alleging -that this implementation of Go or any code incorporated within this -implementation of Go constitutes direct or contributory patent -infringement, or inducement of patent infringement, then any patent -rights granted to you under this License for this implementation of Go -shall terminate as of the date such litigation is filed. diff --git a/vendor/golang.org/x/crypto/acme/acme.go b/vendor/golang.org/x/crypto/acme/acme.go deleted file mode 100644 index ece9113f50..0000000000 --- a/vendor/golang.org/x/crypto/acme/acme.go +++ /dev/null @@ -1,921 +0,0 @@ -// Copyright 2015 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// Package acme provides an implementation of the -// Automatic Certificate Management Environment (ACME) spec. -// See https://tools.ietf.org/html/draft-ietf-acme-acme-02 for details. -// -// Most common scenarios will want to use autocert subdirectory instead, -// which provides automatic access to certificates from Let's Encrypt -// and any other ACME-based CA. -// -// This package is a work in progress and makes no API stability promises. -package acme - -import ( - "context" - "crypto" - "crypto/ecdsa" - "crypto/elliptic" - "crypto/rand" - "crypto/sha256" - "crypto/tls" - "crypto/x509" - "crypto/x509/pkix" - "encoding/asn1" - "encoding/base64" - "encoding/hex" - "encoding/json" - "encoding/pem" - "errors" - "fmt" - "io" - "io/ioutil" - "math/big" - "net/http" - "strings" - "sync" - "time" -) - -const ( - // LetsEncryptURL is the Directory endpoint of Let's Encrypt CA. - LetsEncryptURL = "https://acme-v01.api.letsencrypt.org/directory" - - // ALPNProto is the ALPN protocol name used by a CA server when validating - // tls-alpn-01 challenges. - // - // Package users must ensure their servers can negotiate the ACME ALPN - // in order for tls-alpn-01 challenge verifications to succeed. - ALPNProto = "acme-tls/1" -) - -// idPeACMEIdentifierV1 is the OID for the ACME extension for the TLS-ALPN challenge. -var idPeACMEIdentifierV1 = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 1, 30, 1} - -const ( - maxChainLen = 5 // max depth and breadth of a certificate chain - maxCertSize = 1 << 20 // max size of a certificate, in bytes - - // Max number of collected nonces kept in memory. - // Expect usual peak of 1 or 2. - maxNonces = 100 -) - -// Client is an ACME client. -// The only required field is Key. An example of creating a client with a new key -// is as follows: -// -// key, err := rsa.GenerateKey(rand.Reader, 2048) -// if err != nil { -// log.Fatal(err) -// } -// client := &Client{Key: key} -// -type Client struct { - // Key is the account key used to register with a CA and sign requests. - // Key.Public() must return a *rsa.PublicKey or *ecdsa.PublicKey. - Key crypto.Signer - - // HTTPClient optionally specifies an HTTP client to use - // instead of http.DefaultClient. - HTTPClient *http.Client - - // DirectoryURL points to the CA directory endpoint. - // If empty, LetsEncryptURL is used. - // Mutating this value after a successful call of Client's Discover method - // will have no effect. - DirectoryURL string - - // RetryBackoff computes the duration after which the nth retry of a failed request - // should occur. The value of n for the first call on failure is 1. - // The values of r and resp are the request and response of the last failed attempt. - // If the returned value is negative or zero, no more retries are done and an error - // is returned to the caller of the original method. - // - // Requests which result in a 4xx client error are not retried, - // except for 400 Bad Request due to "bad nonce" errors and 429 Too Many Requests. - // - // If RetryBackoff is nil, a truncated exponential backoff algorithm - // with the ceiling of 10 seconds is used, where each subsequent retry n - // is done after either ("Retry-After" + jitter) or (2^n seconds + jitter), - // preferring the former if "Retry-After" header is found in the resp. - // The jitter is a random value up to 1 second. - RetryBackoff func(n int, r *http.Request, resp *http.Response) time.Duration - - dirMu sync.Mutex // guards writes to dir - dir *Directory // cached result of Client's Discover method - - noncesMu sync.Mutex - nonces map[string]struct{} // nonces collected from previous responses -} - -// Discover performs ACME server discovery using c.DirectoryURL. -// -// It caches successful result. So, subsequent calls will not result in -// a network round-trip. This also means mutating c.DirectoryURL after successful call -// of this method will have no effect. -func (c *Client) Discover(ctx context.Context) (Directory, error) { - c.dirMu.Lock() - defer c.dirMu.Unlock() - if c.dir != nil { - return *c.dir, nil - } - - dirURL := c.DirectoryURL - if dirURL == "" { - dirURL = LetsEncryptURL - } - res, err := c.get(ctx, dirURL, wantStatus(http.StatusOK)) - if err != nil { - return Directory{}, err - } - defer res.Body.Close() - c.addNonce(res.Header) - - var v struct { - Reg string `json:"new-reg"` - Authz string `json:"new-authz"` - Cert string `json:"new-cert"` - Revoke string `json:"revoke-cert"` - Meta struct { - Terms string `json:"terms-of-service"` - Website string `json:"website"` - CAA []string `json:"caa-identities"` - } - } - if err := json.NewDecoder(res.Body).Decode(&v); err != nil { - return Directory{}, err - } - c.dir = &Directory{ - RegURL: v.Reg, - AuthzURL: v.Authz, - CertURL: v.Cert, - RevokeURL: v.Revoke, - Terms: v.Meta.Terms, - Website: v.Meta.Website, - CAA: v.Meta.CAA, - } - return *c.dir, nil -} - -// CreateCert requests a new certificate using the Certificate Signing Request csr encoded in DER format. -// The exp argument indicates the desired certificate validity duration. CA may issue a certificate -// with a different duration. -// If the bundle argument is true, the returned value will also contain the CA (issuer) certificate chain. -// -// In the case where CA server does not provide the issued certificate in the response, -// CreateCert will poll certURL using c.FetchCert, which will result in additional round-trips. -// In such a scenario, the caller can cancel the polling with ctx. -// -// CreateCert returns an error if the CA's response or chain was unreasonably large. -// Callers are encouraged to parse the returned value to ensure the certificate is valid and has the expected features. -func (c *Client) CreateCert(ctx context.Context, csr []byte, exp time.Duration, bundle bool) (der [][]byte, certURL string, err error) { - if _, err := c.Discover(ctx); err != nil { - return nil, "", err - } - - req := struct { - Resource string `json:"resource"` - CSR string `json:"csr"` - NotBefore string `json:"notBefore,omitempty"` - NotAfter string `json:"notAfter,omitempty"` - }{ - Resource: "new-cert", - CSR: base64.RawURLEncoding.EncodeToString(csr), - } - now := timeNow() - req.NotBefore = now.Format(time.RFC3339) - if exp > 0 { - req.NotAfter = now.Add(exp).Format(time.RFC3339) - } - - res, err := c.post(ctx, c.Key, c.dir.CertURL, req, wantStatus(http.StatusCreated)) - if err != nil { - return nil, "", err - } - defer res.Body.Close() - - curl := res.Header.Get("Location") // cert permanent URL - if res.ContentLength == 0 { - // no cert in the body; poll until we get it - cert, err := c.FetchCert(ctx, curl, bundle) - return cert, curl, err - } - // slurp issued cert and CA chain, if requested - cert, err := c.responseCert(ctx, res, bundle) - return cert, curl, err -} - -// FetchCert retrieves already issued certificate from the given url, in DER format. -// It retries the request until the certificate is successfully retrieved, -// context is cancelled by the caller or an error response is received. -// -// The returned value will also contain the CA (issuer) certificate if the bundle argument is true. -// -// FetchCert returns an error if the CA's response or chain was unreasonably large. -// Callers are encouraged to parse the returned value to ensure the certificate is valid -// and has expected features. -func (c *Client) FetchCert(ctx context.Context, url string, bundle bool) ([][]byte, error) { - res, err := c.get(ctx, url, wantStatus(http.StatusOK)) - if err != nil { - return nil, err - } - return c.responseCert(ctx, res, bundle) -} - -// RevokeCert revokes a previously issued certificate cert, provided in DER format. -// -// The key argument, used to sign the request, must be authorized -// to revoke the certificate. It's up to the CA to decide which keys are authorized. -// For instance, the key pair of the certificate may be authorized. -// If the key is nil, c.Key is used instead. -func (c *Client) RevokeCert(ctx context.Context, key crypto.Signer, cert []byte, reason CRLReasonCode) error { - if _, err := c.Discover(ctx); err != nil { - return err - } - - body := &struct { - Resource string `json:"resource"` - Cert string `json:"certificate"` - Reason int `json:"reason"` - }{ - Resource: "revoke-cert", - Cert: base64.RawURLEncoding.EncodeToString(cert), - Reason: int(reason), - } - if key == nil { - key = c.Key - } - res, err := c.post(ctx, key, c.dir.RevokeURL, body, wantStatus(http.StatusOK)) - if err != nil { - return err - } - defer res.Body.Close() - return nil -} - -// AcceptTOS always returns true to indicate the acceptance of a CA's Terms of Service -// during account registration. See Register method of Client for more details. -func AcceptTOS(tosURL string) bool { return true } - -// Register creates a new account registration by following the "new-reg" flow. -// It returns the registered account. The account is not modified. -// -// The registration may require the caller to agree to the CA's Terms of Service (TOS). -// If so, and the account has not indicated the acceptance of the terms (see Account for details), -// Register calls prompt with a TOS URL provided by the CA. Prompt should report -// whether the caller agrees to the terms. To always accept the terms, the caller can use AcceptTOS. -func (c *Client) Register(ctx context.Context, a *Account, prompt func(tosURL string) bool) (*Account, error) { - if _, err := c.Discover(ctx); err != nil { - return nil, err - } - - var err error - if a, err = c.doReg(ctx, c.dir.RegURL, "new-reg", a); err != nil { - return nil, err - } - var accept bool - if a.CurrentTerms != "" && a.CurrentTerms != a.AgreedTerms { - accept = prompt(a.CurrentTerms) - } - if accept { - a.AgreedTerms = a.CurrentTerms - a, err = c.UpdateReg(ctx, a) - } - return a, err -} - -// GetReg retrieves an existing registration. -// The url argument is an Account URI. -func (c *Client) GetReg(ctx context.Context, url string) (*Account, error) { - a, err := c.doReg(ctx, url, "reg", nil) - if err != nil { - return nil, err - } - a.URI = url - return a, nil -} - -// UpdateReg updates an existing registration. -// It returns an updated account copy. The provided account is not modified. -func (c *Client) UpdateReg(ctx context.Context, a *Account) (*Account, error) { - uri := a.URI - a, err := c.doReg(ctx, uri, "reg", a) - if err != nil { - return nil, err - } - a.URI = uri - return a, nil -} - -// Authorize performs the initial step in an authorization flow. -// The caller will then need to choose from and perform a set of returned -// challenges using c.Accept in order to successfully complete authorization. -// -// If an authorization has been previously granted, the CA may return -// a valid authorization (Authorization.Status is StatusValid). If so, the caller -// need not fulfill any challenge and can proceed to requesting a certificate. -func (c *Client) Authorize(ctx context.Context, domain string) (*Authorization, error) { - if _, err := c.Discover(ctx); err != nil { - return nil, err - } - - type authzID struct { - Type string `json:"type"` - Value string `json:"value"` - } - req := struct { - Resource string `json:"resource"` - Identifier authzID `json:"identifier"` - }{ - Resource: "new-authz", - Identifier: authzID{Type: "dns", Value: domain}, - } - res, err := c.post(ctx, c.Key, c.dir.AuthzURL, req, wantStatus(http.StatusCreated)) - if err != nil { - return nil, err - } - defer res.Body.Close() - - var v wireAuthz - if err := json.NewDecoder(res.Body).Decode(&v); err != nil { - return nil, fmt.Errorf("acme: invalid response: %v", err) - } - if v.Status != StatusPending && v.Status != StatusValid { - return nil, fmt.Errorf("acme: unexpected status: %s", v.Status) - } - return v.authorization(res.Header.Get("Location")), nil -} - -// GetAuthorization retrieves an authorization identified by the given URL. -// -// If a caller needs to poll an authorization until its status is final, -// see the WaitAuthorization method. -func (c *Client) GetAuthorization(ctx context.Context, url string) (*Authorization, error) { - res, err := c.get(ctx, url, wantStatus(http.StatusOK, http.StatusAccepted)) - if err != nil { - return nil, err - } - defer res.Body.Close() - var v wireAuthz - if err := json.NewDecoder(res.Body).Decode(&v); err != nil { - return nil, fmt.Errorf("acme: invalid response: %v", err) - } - return v.authorization(url), nil -} - -// RevokeAuthorization relinquishes an existing authorization identified -// by the given URL. -// The url argument is an Authorization.URI value. -// -// If successful, the caller will be required to obtain a new authorization -// using the Authorize method before being able to request a new certificate -// for the domain associated with the authorization. -// -// It does not revoke existing certificates. -func (c *Client) RevokeAuthorization(ctx context.Context, url string) error { - req := struct { - Resource string `json:"resource"` - Status string `json:"status"` - Delete bool `json:"delete"` - }{ - Resource: "authz", - Status: "deactivated", - Delete: true, - } - res, err := c.post(ctx, c.Key, url, req, wantStatus(http.StatusOK)) - if err != nil { - return err - } - defer res.Body.Close() - return nil -} - -// WaitAuthorization polls an authorization at the given URL -// until it is in one of the final states, StatusValid or StatusInvalid, -// the ACME CA responded with a 4xx error code, or the context is done. -// -// It returns a non-nil Authorization only if its Status is StatusValid. -// In all other cases WaitAuthorization returns an error. -// If the Status is StatusInvalid, the returned error is of type *AuthorizationError. -func (c *Client) WaitAuthorization(ctx context.Context, url string) (*Authorization, error) { - for { - res, err := c.get(ctx, url, wantStatus(http.StatusOK, http.StatusAccepted)) - if err != nil { - return nil, err - } - - var raw wireAuthz - err = json.NewDecoder(res.Body).Decode(&raw) - res.Body.Close() - switch { - case err != nil: - // Skip and retry. - case raw.Status == StatusValid: - return raw.authorization(url), nil - case raw.Status == StatusInvalid: - return nil, raw.error(url) - } - - // Exponential backoff is implemented in c.get above. - // This is just to prevent continuously hitting the CA - // while waiting for a final authorization status. - d := retryAfter(res.Header.Get("Retry-After")) - if d == 0 { - // Given that the fastest challenges TLS-SNI and HTTP-01 - // require a CA to make at least 1 network round trip - // and most likely persist a challenge state, - // this default delay seems reasonable. - d = time.Second - } - t := time.NewTimer(d) - select { - case <-ctx.Done(): - t.Stop() - return nil, ctx.Err() - case <-t.C: - // Retry. - } - } -} - -// GetChallenge retrieves the current status of an challenge. -// -// A client typically polls a challenge status using this method. -func (c *Client) GetChallenge(ctx context.Context, url string) (*Challenge, error) { - res, err := c.get(ctx, url, wantStatus(http.StatusOK, http.StatusAccepted)) - if err != nil { - return nil, err - } - defer res.Body.Close() - v := wireChallenge{URI: url} - if err := json.NewDecoder(res.Body).Decode(&v); err != nil { - return nil, fmt.Errorf("acme: invalid response: %v", err) - } - return v.challenge(), nil -} - -// Accept informs the server that the client accepts one of its challenges -// previously obtained with c.Authorize. -// -// The server will then perform the validation asynchronously. -func (c *Client) Accept(ctx context.Context, chal *Challenge) (*Challenge, error) { - auth, err := keyAuth(c.Key.Public(), chal.Token) - if err != nil { - return nil, err - } - - req := struct { - Resource string `json:"resource"` - Type string `json:"type"` - Auth string `json:"keyAuthorization"` - }{ - Resource: "challenge", - Type: chal.Type, - Auth: auth, - } - res, err := c.post(ctx, c.Key, chal.URI, req, wantStatus( - http.StatusOK, // according to the spec - http.StatusAccepted, // Let's Encrypt: see https://goo.gl/WsJ7VT (acme-divergences.md) - )) - if err != nil { - return nil, err - } - defer res.Body.Close() - - var v wireChallenge - if err := json.NewDecoder(res.Body).Decode(&v); err != nil { - return nil, fmt.Errorf("acme: invalid response: %v", err) - } - return v.challenge(), nil -} - -// DNS01ChallengeRecord returns a DNS record value for a dns-01 challenge response. -// A TXT record containing the returned value must be provisioned under -// "_acme-challenge" name of the domain being validated. -// -// The token argument is a Challenge.Token value. -func (c *Client) DNS01ChallengeRecord(token string) (string, error) { - ka, err := keyAuth(c.Key.Public(), token) - if err != nil { - return "", err - } - b := sha256.Sum256([]byte(ka)) - return base64.RawURLEncoding.EncodeToString(b[:]), nil -} - -// HTTP01ChallengeResponse returns the response for an http-01 challenge. -// Servers should respond with the value to HTTP requests at the URL path -// provided by HTTP01ChallengePath to validate the challenge and prove control -// over a domain name. -// -// The token argument is a Challenge.Token value. -func (c *Client) HTTP01ChallengeResponse(token string) (string, error) { - return keyAuth(c.Key.Public(), token) -} - -// HTTP01ChallengePath returns the URL path at which the response for an http-01 challenge -// should be provided by the servers. -// The response value can be obtained with HTTP01ChallengeResponse. -// -// The token argument is a Challenge.Token value. -func (c *Client) HTTP01ChallengePath(token string) string { - return "/.well-known/acme-challenge/" + token -} - -// TLSSNI01ChallengeCert creates a certificate for TLS-SNI-01 challenge response. -// Servers can present the certificate to validate the challenge and prove control -// over a domain name. -// -// The implementation is incomplete in that the returned value is a single certificate, -// computed only for Z0 of the key authorization. ACME CAs are expected to update -// their implementations to use the newer version, TLS-SNI-02. -// For more details on TLS-SNI-01 see https://tools.ietf.org/html/draft-ietf-acme-acme-01#section-7.3. -// -// The token argument is a Challenge.Token value. -// If a WithKey option is provided, its private part signs the returned cert, -// and the public part is used to specify the signee. -// If no WithKey option is provided, a new ECDSA key is generated using P-256 curve. -// -// The returned certificate is valid for the next 24 hours and must be presented only when -// the server name of the TLS ClientHello matches exactly the returned name value. -func (c *Client) TLSSNI01ChallengeCert(token string, opt ...CertOption) (cert tls.Certificate, name string, err error) { - ka, err := keyAuth(c.Key.Public(), token) - if err != nil { - return tls.Certificate{}, "", err - } - b := sha256.Sum256([]byte(ka)) - h := hex.EncodeToString(b[:]) - name = fmt.Sprintf("%s.%s.acme.invalid", h[:32], h[32:]) - cert, err = tlsChallengeCert([]string{name}, opt) - if err != nil { - return tls.Certificate{}, "", err - } - return cert, name, nil -} - -// TLSSNI02ChallengeCert creates a certificate for TLS-SNI-02 challenge response. -// Servers can present the certificate to validate the challenge and prove control -// over a domain name. For more details on TLS-SNI-02 see -// https://tools.ietf.org/html/draft-ietf-acme-acme-03#section-7.3. -// -// The token argument is a Challenge.Token value. -// If a WithKey option is provided, its private part signs the returned cert, -// and the public part is used to specify the signee. -// If no WithKey option is provided, a new ECDSA key is generated using P-256 curve. -// -// The returned certificate is valid for the next 24 hours and must be presented only when -// the server name in the TLS ClientHello matches exactly the returned name value. -func (c *Client) TLSSNI02ChallengeCert(token string, opt ...CertOption) (cert tls.Certificate, name string, err error) { - b := sha256.Sum256([]byte(token)) - h := hex.EncodeToString(b[:]) - sanA := fmt.Sprintf("%s.%s.token.acme.invalid", h[:32], h[32:]) - - ka, err := keyAuth(c.Key.Public(), token) - if err != nil { - return tls.Certificate{}, "", err - } - b = sha256.Sum256([]byte(ka)) - h = hex.EncodeToString(b[:]) - sanB := fmt.Sprintf("%s.%s.ka.acme.invalid", h[:32], h[32:]) - - cert, err = tlsChallengeCert([]string{sanA, sanB}, opt) - if err != nil { - return tls.Certificate{}, "", err - } - return cert, sanA, nil -} - -// TLSALPN01ChallengeCert creates a certificate for TLS-ALPN-01 challenge response. -// Servers can present the certificate to validate the challenge and prove control -// over a domain name. For more details on TLS-ALPN-01 see -// https://tools.ietf.org/html/draft-shoemaker-acme-tls-alpn-00#section-3 -// -// The token argument is a Challenge.Token value. -// If a WithKey option is provided, its private part signs the returned cert, -// and the public part is used to specify the signee. -// If no WithKey option is provided, a new ECDSA key is generated using P-256 curve. -// -// The returned certificate is valid for the next 24 hours and must be presented only when -// the server name in the TLS ClientHello matches the domain, and the special acme-tls/1 ALPN protocol -// has been specified. -func (c *Client) TLSALPN01ChallengeCert(token, domain string, opt ...CertOption) (cert tls.Certificate, err error) { - ka, err := keyAuth(c.Key.Public(), token) - if err != nil { - return tls.Certificate{}, err - } - shasum := sha256.Sum256([]byte(ka)) - extValue, err := asn1.Marshal(shasum[:]) - if err != nil { - return tls.Certificate{}, err - } - acmeExtension := pkix.Extension{ - Id: idPeACMEIdentifierV1, - Critical: true, - Value: extValue, - } - - tmpl := defaultTLSChallengeCertTemplate() - - var newOpt []CertOption - for _, o := range opt { - switch o := o.(type) { - case *certOptTemplate: - t := *(*x509.Certificate)(o) // shallow copy is ok - tmpl = &t - default: - newOpt = append(newOpt, o) - } - } - tmpl.ExtraExtensions = append(tmpl.ExtraExtensions, acmeExtension) - newOpt = append(newOpt, WithTemplate(tmpl)) - return tlsChallengeCert([]string{domain}, newOpt) -} - -// doReg sends all types of registration requests. -// The type of request is identified by typ argument, which is a "resource" -// in the ACME spec terms. -// -// A non-nil acct argument indicates whether the intention is to mutate data -// of the Account. Only Contact and Agreement of its fields are used -// in such cases. -func (c *Client) doReg(ctx context.Context, url string, typ string, acct *Account) (*Account, error) { - req := struct { - Resource string `json:"resource"` - Contact []string `json:"contact,omitempty"` - Agreement string `json:"agreement,omitempty"` - }{ - Resource: typ, - } - if acct != nil { - req.Contact = acct.Contact - req.Agreement = acct.AgreedTerms - } - res, err := c.post(ctx, c.Key, url, req, wantStatus( - http.StatusOK, // updates and deletes - http.StatusCreated, // new account creation - http.StatusAccepted, // Let's Encrypt divergent implementation - )) - if err != nil { - return nil, err - } - defer res.Body.Close() - - var v struct { - Contact []string - Agreement string - Authorizations string - Certificates string - } - if err := json.NewDecoder(res.Body).Decode(&v); err != nil { - return nil, fmt.Errorf("acme: invalid response: %v", err) - } - var tos string - if v := linkHeader(res.Header, "terms-of-service"); len(v) > 0 { - tos = v[0] - } - var authz string - if v := linkHeader(res.Header, "next"); len(v) > 0 { - authz = v[0] - } - return &Account{ - URI: res.Header.Get("Location"), - Contact: v.Contact, - AgreedTerms: v.Agreement, - CurrentTerms: tos, - Authz: authz, - Authorizations: v.Authorizations, - Certificates: v.Certificates, - }, nil -} - -// popNonce returns a nonce value previously stored with c.addNonce -// or fetches a fresh one from the given URL. -func (c *Client) popNonce(ctx context.Context, url string) (string, error) { - c.noncesMu.Lock() - defer c.noncesMu.Unlock() - if len(c.nonces) == 0 { - return c.fetchNonce(ctx, url) - } - var nonce string - for nonce = range c.nonces { - delete(c.nonces, nonce) - break - } - return nonce, nil -} - -// clearNonces clears any stored nonces -func (c *Client) clearNonces() { - c.noncesMu.Lock() - defer c.noncesMu.Unlock() - c.nonces = make(map[string]struct{}) -} - -// addNonce stores a nonce value found in h (if any) for future use. -func (c *Client) addNonce(h http.Header) { - v := nonceFromHeader(h) - if v == "" { - return - } - c.noncesMu.Lock() - defer c.noncesMu.Unlock() - if len(c.nonces) >= maxNonces { - return - } - if c.nonces == nil { - c.nonces = make(map[string]struct{}) - } - c.nonces[v] = struct{}{} -} - -func (c *Client) fetchNonce(ctx context.Context, url string) (string, error) { - r, err := http.NewRequest("HEAD", url, nil) - if err != nil { - return "", err - } - resp, err := c.doNoRetry(ctx, r) - if err != nil { - return "", err - } - defer resp.Body.Close() - nonce := nonceFromHeader(resp.Header) - if nonce == "" { - if resp.StatusCode > 299 { - return "", responseError(resp) - } - return "", errors.New("acme: nonce not found") - } - return nonce, nil -} - -func nonceFromHeader(h http.Header) string { - return h.Get("Replay-Nonce") -} - -func (c *Client) responseCert(ctx context.Context, res *http.Response, bundle bool) ([][]byte, error) { - b, err := ioutil.ReadAll(io.LimitReader(res.Body, maxCertSize+1)) - if err != nil { - return nil, fmt.Errorf("acme: response stream: %v", err) - } - if len(b) > maxCertSize { - return nil, errors.New("acme: certificate is too big") - } - cert := [][]byte{b} - if !bundle { - return cert, nil - } - - // Append CA chain cert(s). - // At least one is required according to the spec: - // https://tools.ietf.org/html/draft-ietf-acme-acme-03#section-6.3.1 - up := linkHeader(res.Header, "up") - if len(up) == 0 { - return nil, errors.New("acme: rel=up link not found") - } - if len(up) > maxChainLen { - return nil, errors.New("acme: rel=up link is too large") - } - for _, url := range up { - cc, err := c.chainCert(ctx, url, 0) - if err != nil { - return nil, err - } - cert = append(cert, cc...) - } - return cert, nil -} - -// chainCert fetches CA certificate chain recursively by following "up" links. -// Each recursive call increments the depth by 1, resulting in an error -// if the recursion level reaches maxChainLen. -// -// First chainCert call starts with depth of 0. -func (c *Client) chainCert(ctx context.Context, url string, depth int) ([][]byte, error) { - if depth >= maxChainLen { - return nil, errors.New("acme: certificate chain is too deep") - } - - res, err := c.get(ctx, url, wantStatus(http.StatusOK)) - if err != nil { - return nil, err - } - defer res.Body.Close() - b, err := ioutil.ReadAll(io.LimitReader(res.Body, maxCertSize+1)) - if err != nil { - return nil, err - } - if len(b) > maxCertSize { - return nil, errors.New("acme: certificate is too big") - } - chain := [][]byte{b} - - uplink := linkHeader(res.Header, "up") - if len(uplink) > maxChainLen { - return nil, errors.New("acme: certificate chain is too large") - } - for _, up := range uplink { - cc, err := c.chainCert(ctx, up, depth+1) - if err != nil { - return nil, err - } - chain = append(chain, cc...) - } - - return chain, nil -} - -// linkHeader returns URI-Reference values of all Link headers -// with relation-type rel. -// See https://tools.ietf.org/html/rfc5988#section-5 for details. -func linkHeader(h http.Header, rel string) []string { - var links []string - for _, v := range h["Link"] { - parts := strings.Split(v, ";") - for _, p := range parts { - p = strings.TrimSpace(p) - if !strings.HasPrefix(p, "rel=") { - continue - } - if v := strings.Trim(p[4:], `"`); v == rel { - links = append(links, strings.Trim(parts[0], "<>")) - } - } - } - return links -} - -// keyAuth generates a key authorization string for a given token. -func keyAuth(pub crypto.PublicKey, token string) (string, error) { - th, err := JWKThumbprint(pub) - if err != nil { - return "", err - } - return fmt.Sprintf("%s.%s", token, th), nil -} - -// defaultTLSChallengeCertTemplate is a template used to create challenge certs for TLS challenges. -func defaultTLSChallengeCertTemplate() *x509.Certificate { - return &x509.Certificate{ - SerialNumber: big.NewInt(1), - NotBefore: time.Now(), - NotAfter: time.Now().Add(24 * time.Hour), - BasicConstraintsValid: true, - KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, - ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, - } -} - -// tlsChallengeCert creates a temporary certificate for TLS-SNI challenges -// with the given SANs and auto-generated public/private key pair. -// The Subject Common Name is set to the first SAN to aid debugging. -// To create a cert with a custom key pair, specify WithKey option. -func tlsChallengeCert(san []string, opt []CertOption) (tls.Certificate, error) { - var key crypto.Signer - tmpl := defaultTLSChallengeCertTemplate() - for _, o := range opt { - switch o := o.(type) { - case *certOptKey: - if key != nil { - return tls.Certificate{}, errors.New("acme: duplicate key option") - } - key = o.key - case *certOptTemplate: - t := *(*x509.Certificate)(o) // shallow copy is ok - tmpl = &t - default: - // package's fault, if we let this happen: - panic(fmt.Sprintf("unsupported option type %T", o)) - } - } - if key == nil { - var err error - if key, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader); err != nil { - return tls.Certificate{}, err - } - } - tmpl.DNSNames = san - if len(san) > 0 { - tmpl.Subject.CommonName = san[0] - } - - der, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, key.Public(), key) - if err != nil { - return tls.Certificate{}, err - } - return tls.Certificate{ - Certificate: [][]byte{der}, - PrivateKey: key, - }, nil -} - -// encodePEM returns b encoded as PEM with block of type typ. -func encodePEM(typ string, b []byte) []byte { - pb := &pem.Block{Type: typ, Bytes: b} - return pem.EncodeToMemory(pb) -} - -// timeNow is useful for testing for fixed current time. -var timeNow = time.Now diff --git a/vendor/golang.org/x/crypto/acme/autocert/autocert.go b/vendor/golang.org/x/crypto/acme/autocert/autocert.go deleted file mode 100644 index 1a9d972d37..0000000000 --- a/vendor/golang.org/x/crypto/acme/autocert/autocert.go +++ /dev/null @@ -1,1127 +0,0 @@ -// Copyright 2016 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// Package autocert provides automatic access to certificates from Let's Encrypt -// and any other ACME-based CA. -// -// This package is a work in progress and makes no API stability promises. -package autocert - -import ( - "bytes" - "context" - "crypto" - "crypto/ecdsa" - "crypto/elliptic" - "crypto/rand" - "crypto/rsa" - "crypto/tls" - "crypto/x509" - "crypto/x509/pkix" - "encoding/pem" - "errors" - "fmt" - "io" - mathrand "math/rand" - "net" - "net/http" - "path" - "strings" - "sync" - "time" - - "golang.org/x/crypto/acme" -) - -// createCertRetryAfter is how much time to wait before removing a failed state -// entry due to an unsuccessful createCert call. -// This is a variable instead of a const for testing. -// TODO: Consider making it configurable or an exp backoff? -var createCertRetryAfter = time.Minute - -// pseudoRand is safe for concurrent use. -var pseudoRand *lockedMathRand - -func init() { - src := mathrand.NewSource(timeNow().UnixNano()) - pseudoRand = &lockedMathRand{rnd: mathrand.New(src)} -} - -// AcceptTOS is a Manager.Prompt function that always returns true to -// indicate acceptance of the CA's Terms of Service during account -// registration. -func AcceptTOS(tosURL string) bool { return true } - -// HostPolicy specifies which host names the Manager is allowed to respond to. -// It returns a non-nil error if the host should be rejected. -// The returned error is accessible via tls.Conn.Handshake and its callers. -// See Manager's HostPolicy field and GetCertificate method docs for more details. -type HostPolicy func(ctx context.Context, host string) error - -// HostWhitelist returns a policy where only the specified host names are allowed. -// Only exact matches are currently supported. Subdomains, regexp or wildcard -// will not match. -func HostWhitelist(hosts ...string) HostPolicy { - whitelist := make(map[string]bool, len(hosts)) - for _, h := range hosts { - whitelist[h] = true - } - return func(_ context.Context, host string) error { - if !whitelist[host] { - return errors.New("acme/autocert: host not configured") - } - return nil - } -} - -// defaultHostPolicy is used when Manager.HostPolicy is not set. -func defaultHostPolicy(context.Context, string) error { - return nil -} - -// Manager is a stateful certificate manager built on top of acme.Client. -// It obtains and refreshes certificates automatically using "tls-alpn-01", -// "tls-sni-01", "tls-sni-02" and "http-01" challenge types, -// as well as providing them to a TLS server via tls.Config. -// -// You must specify a cache implementation, such as DirCache, -// to reuse obtained certificates across program restarts. -// Otherwise your server is very likely to exceed the certificate -// issuer's request rate limits. -type Manager struct { - // Prompt specifies a callback function to conditionally accept a CA's Terms of Service (TOS). - // The registration may require the caller to agree to the CA's TOS. - // If so, Manager calls Prompt with a TOS URL provided by the CA. Prompt should report - // whether the caller agrees to the terms. - // - // To always accept the terms, the callers can use AcceptTOS. - Prompt func(tosURL string) bool - - // Cache optionally stores and retrieves previously-obtained certificates - // and other state. If nil, certs will only be cached for the lifetime of - // the Manager. Multiple Managers can share the same Cache. - // - // Using a persistent Cache, such as DirCache, is strongly recommended. - Cache Cache - - // HostPolicy controls which domains the Manager will attempt - // to retrieve new certificates for. It does not affect cached certs. - // - // If non-nil, HostPolicy is called before requesting a new cert. - // If nil, all hosts are currently allowed. This is not recommended, - // as it opens a potential attack where clients connect to a server - // by IP address and pretend to be asking for an incorrect host name. - // Manager will attempt to obtain a certificate for that host, incorrectly, - // eventually reaching the CA's rate limit for certificate requests - // and making it impossible to obtain actual certificates. - // - // See GetCertificate for more details. - HostPolicy HostPolicy - - // RenewBefore optionally specifies how early certificates should - // be renewed before they expire. - // - // If zero, they're renewed 30 days before expiration. - RenewBefore time.Duration - - // Client is used to perform low-level operations, such as account registration - // and requesting new certificates. - // - // If Client is nil, a zero-value acme.Client is used with acme.LetsEncryptURL - // as directory endpoint. If the Client.Key is nil, a new ECDSA P-256 key is - // generated and, if Cache is not nil, stored in cache. - // - // Mutating the field after the first call of GetCertificate method will have no effect. - Client *acme.Client - - // Email optionally specifies a contact email address. - // This is used by CAs, such as Let's Encrypt, to notify about problems - // with issued certificates. - // - // If the Client's account key is already registered, Email is not used. - Email string - - // ForceRSA used to make the Manager generate RSA certificates. It is now ignored. - // - // Deprecated: the Manager will request the correct type of certificate based - // on what each client supports. - ForceRSA bool - - // ExtraExtensions are used when generating a new CSR (Certificate Request), - // thus allowing customization of the resulting certificate. - // For instance, TLS Feature Extension (RFC 7633) can be used - // to prevent an OCSP downgrade attack. - // - // The field value is passed to crypto/x509.CreateCertificateRequest - // in the template's ExtraExtensions field as is. - ExtraExtensions []pkix.Extension - - clientMu sync.Mutex - client *acme.Client // initialized by acmeClient method - - stateMu sync.Mutex - state map[certKey]*certState - - // renewal tracks the set of domains currently running renewal timers. - renewalMu sync.Mutex - renewal map[certKey]*domainRenewal - - // tokensMu guards the rest of the fields: tryHTTP01, certTokens and httpTokens. - tokensMu sync.RWMutex - // tryHTTP01 indicates whether the Manager should try "http-01" challenge type - // during the authorization flow. - tryHTTP01 bool - // httpTokens contains response body values for http-01 challenges - // and is keyed by the URL path at which a challenge response is expected - // to be provisioned. - // The entries are stored for the duration of the authorization flow. - httpTokens map[string][]byte - // certTokens contains temporary certificates for tls-sni and tls-alpn challenges - // and is keyed by token domain name, which matches server name of ClientHello. - // Keys always have ".acme.invalid" suffix for tls-sni. Otherwise, they are domain names - // for tls-alpn. - // The entries are stored for the duration of the authorization flow. - certTokens map[string]*tls.Certificate -} - -// certKey is the key by which certificates are tracked in state, renewal and cache. -type certKey struct { - domain string // without trailing dot - isRSA bool // RSA cert for legacy clients (as opposed to default ECDSA) - isToken bool // tls-based challenge token cert; key type is undefined regardless of isRSA -} - -func (c certKey) String() string { - if c.isToken { - return c.domain + "+token" - } - if c.isRSA { - return c.domain + "+rsa" - } - return c.domain -} - -// TLSConfig creates a new TLS config suitable for net/http.Server servers, -// supporting HTTP/2 and the tls-alpn-01 ACME challenge type. -func (m *Manager) TLSConfig() *tls.Config { - return &tls.Config{ - GetCertificate: m.GetCertificate, - NextProtos: []string{ - "h2", "http/1.1", // enable HTTP/2 - acme.ALPNProto, // enable tls-alpn ACME challenges - }, - } -} - -// GetCertificate implements the tls.Config.GetCertificate hook. -// It provides a TLS certificate for hello.ServerName host, including answering -// tls-alpn-01 and *.acme.invalid (tls-sni-01 and tls-sni-02) challenges. -// All other fields of hello are ignored. -// -// If m.HostPolicy is non-nil, GetCertificate calls the policy before requesting -// a new cert. A non-nil error returned from m.HostPolicy halts TLS negotiation. -// The error is propagated back to the caller of GetCertificate and is user-visible. -// This does not affect cached certs. See HostPolicy field description for more details. -func (m *Manager) GetCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate, error) { - if m.Prompt == nil { - return nil, errors.New("acme/autocert: Manager.Prompt not set") - } - - name := hello.ServerName - if name == "" { - return nil, errors.New("acme/autocert: missing server name") - } - if !strings.Contains(strings.Trim(name, "."), ".") { - return nil, errors.New("acme/autocert: server name component count invalid") - } - if strings.ContainsAny(name, `+/\`) { - return nil, errors.New("acme/autocert: server name contains invalid character") - } - - // In the worst-case scenario, the timeout needs to account for caching, host policy, - // domain ownership verification and certificate issuance. - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) - defer cancel() - - // Check whether this is a token cert requested for TLS-SNI or TLS-ALPN challenge. - if wantsTokenCert(hello) { - m.tokensMu.RLock() - defer m.tokensMu.RUnlock() - // It's ok to use the same token cert key for both tls-sni and tls-alpn - // because there's always at most 1 token cert per on-going domain authorization. - // See m.verify for details. - if cert := m.certTokens[name]; cert != nil { - return cert, nil - } - if cert, err := m.cacheGet(ctx, certKey{domain: name, isToken: true}); err == nil { - return cert, nil - } - // TODO: cache error results? - return nil, fmt.Errorf("acme/autocert: no token cert for %q", name) - } - - // regular domain - ck := certKey{ - domain: strings.TrimSuffix(name, "."), // golang.org/issue/18114 - isRSA: !supportsECDSA(hello), - } - cert, err := m.cert(ctx, ck) - if err == nil { - return cert, nil - } - if err != ErrCacheMiss { - return nil, err - } - - // first-time - if err := m.hostPolicy()(ctx, name); err != nil { - return nil, err - } - cert, err = m.createCert(ctx, ck) - if err != nil { - return nil, err - } - m.cachePut(ctx, ck, cert) - return cert, nil -} - -// wantsTokenCert reports whether a TLS request with SNI is made by a CA server -// for a challenge verification. -func wantsTokenCert(hello *tls.ClientHelloInfo) bool { - // tls-alpn-01 - if len(hello.SupportedProtos) == 1 && hello.SupportedProtos[0] == acme.ALPNProto { - return true - } - // tls-sni-xx - return strings.HasSuffix(hello.ServerName, ".acme.invalid") -} - -func supportsECDSA(hello *tls.ClientHelloInfo) bool { - // The "signature_algorithms" extension, if present, limits the key exchange - // algorithms allowed by the cipher suites. See RFC 5246, section 7.4.1.4.1. - if hello.SignatureSchemes != nil { - ecdsaOK := false - schemeLoop: - for _, scheme := range hello.SignatureSchemes { - const tlsECDSAWithSHA1 tls.SignatureScheme = 0x0203 // constant added in Go 1.10 - switch scheme { - case tlsECDSAWithSHA1, tls.ECDSAWithP256AndSHA256, - tls.ECDSAWithP384AndSHA384, tls.ECDSAWithP521AndSHA512: - ecdsaOK = true - break schemeLoop - } - } - if !ecdsaOK { - return false - } - } - if hello.SupportedCurves != nil { - ecdsaOK := false - for _, curve := range hello.SupportedCurves { - if curve == tls.CurveP256 { - ecdsaOK = true - break - } - } - if !ecdsaOK { - return false - } - } - for _, suite := range hello.CipherSuites { - switch suite { - case tls.TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, - tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, - tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, - tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, - tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, - tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, - tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305: - return true - } - } - return false -} - -// HTTPHandler configures the Manager to provision ACME "http-01" challenge responses. -// It returns an http.Handler that responds to the challenges and must be -// running on port 80. If it receives a request that is not an ACME challenge, -// it delegates the request to the optional fallback handler. -// -// If fallback is nil, the returned handler redirects all GET and HEAD requests -// to the default TLS port 443 with 302 Found status code, preserving the original -// request path and query. It responds with 400 Bad Request to all other HTTP methods. -// The fallback is not protected by the optional HostPolicy. -// -// Because the fallback handler is run with unencrypted port 80 requests, -// the fallback should not serve TLS-only requests. -// -// If HTTPHandler is never called, the Manager will only use TLS SNI -// challenges for domain verification. -func (m *Manager) HTTPHandler(fallback http.Handler) http.Handler { - m.tokensMu.Lock() - defer m.tokensMu.Unlock() - m.tryHTTP01 = true - - if fallback == nil { - fallback = http.HandlerFunc(handleHTTPRedirect) - } - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if !strings.HasPrefix(r.URL.Path, "/.well-known/acme-challenge/") { - fallback.ServeHTTP(w, r) - return - } - // A reasonable context timeout for cache and host policy only, - // because we don't wait for a new certificate issuance here. - ctx, cancel := context.WithTimeout(r.Context(), time.Minute) - defer cancel() - if err := m.hostPolicy()(ctx, r.Host); err != nil { - http.Error(w, err.Error(), http.StatusForbidden) - return - } - data, err := m.httpToken(ctx, r.URL.Path) - if err != nil { - http.Error(w, err.Error(), http.StatusNotFound) - return - } - w.Write(data) - }) -} - -func handleHTTPRedirect(w http.ResponseWriter, r *http.Request) { - if r.Method != "GET" && r.Method != "HEAD" { - http.Error(w, "Use HTTPS", http.StatusBadRequest) - return - } - target := "https://" + stripPort(r.Host) + r.URL.RequestURI() - http.Redirect(w, r, target, http.StatusFound) -} - -func stripPort(hostport string) string { - host, _, err := net.SplitHostPort(hostport) - if err != nil { - return hostport - } - return net.JoinHostPort(host, "443") -} - -// cert returns an existing certificate either from m.state or cache. -// If a certificate is found in cache but not in m.state, the latter will be filled -// with the cached value. -func (m *Manager) cert(ctx context.Context, ck certKey) (*tls.Certificate, error) { - m.stateMu.Lock() - if s, ok := m.state[ck]; ok { - m.stateMu.Unlock() - s.RLock() - defer s.RUnlock() - return s.tlscert() - } - defer m.stateMu.Unlock() - cert, err := m.cacheGet(ctx, ck) - if err != nil { - return nil, err - } - signer, ok := cert.PrivateKey.(crypto.Signer) - if !ok { - return nil, errors.New("acme/autocert: private key cannot sign") - } - if m.state == nil { - m.state = make(map[certKey]*certState) - } - s := &certState{ - key: signer, - cert: cert.Certificate, - leaf: cert.Leaf, - } - m.state[ck] = s - go m.renew(ck, s.key, s.leaf.NotAfter) - return cert, nil -} - -// cacheGet always returns a valid certificate, or an error otherwise. -// If a cached certificate exists but is not valid, ErrCacheMiss is returned. -func (m *Manager) cacheGet(ctx context.Context, ck certKey) (*tls.Certificate, error) { - if m.Cache == nil { - return nil, ErrCacheMiss - } - data, err := m.Cache.Get(ctx, ck.String()) - if err != nil { - return nil, err - } - - // private - priv, pub := pem.Decode(data) - if priv == nil || !strings.Contains(priv.Type, "PRIVATE") { - return nil, ErrCacheMiss - } - privKey, err := parsePrivateKey(priv.Bytes) - if err != nil { - return nil, err - } - - // public - var pubDER [][]byte - for len(pub) > 0 { - var b *pem.Block - b, pub = pem.Decode(pub) - if b == nil { - break - } - pubDER = append(pubDER, b.Bytes) - } - if len(pub) > 0 { - // Leftover content not consumed by pem.Decode. Corrupt. Ignore. - return nil, ErrCacheMiss - } - - // verify and create TLS cert - leaf, err := validCert(ck, pubDER, privKey) - if err != nil { - return nil, ErrCacheMiss - } - tlscert := &tls.Certificate{ - Certificate: pubDER, - PrivateKey: privKey, - Leaf: leaf, - } - return tlscert, nil -} - -func (m *Manager) cachePut(ctx context.Context, ck certKey, tlscert *tls.Certificate) error { - if m.Cache == nil { - return nil - } - - // contains PEM-encoded data - var buf bytes.Buffer - - // private - switch key := tlscert.PrivateKey.(type) { - case *ecdsa.PrivateKey: - if err := encodeECDSAKey(&buf, key); err != nil { - return err - } - case *rsa.PrivateKey: - b := x509.MarshalPKCS1PrivateKey(key) - pb := &pem.Block{Type: "RSA PRIVATE KEY", Bytes: b} - if err := pem.Encode(&buf, pb); err != nil { - return err - } - default: - return errors.New("acme/autocert: unknown private key type") - } - - // public - for _, b := range tlscert.Certificate { - pb := &pem.Block{Type: "CERTIFICATE", Bytes: b} - if err := pem.Encode(&buf, pb); err != nil { - return err - } - } - - return m.Cache.Put(ctx, ck.String(), buf.Bytes()) -} - -func encodeECDSAKey(w io.Writer, key *ecdsa.PrivateKey) error { - b, err := x509.MarshalECPrivateKey(key) - if err != nil { - return err - } - pb := &pem.Block{Type: "EC PRIVATE KEY", Bytes: b} - return pem.Encode(w, pb) -} - -// createCert starts the domain ownership verification and returns a certificate -// for that domain upon success. -// -// If the domain is already being verified, it waits for the existing verification to complete. -// Either way, createCert blocks for the duration of the whole process. -func (m *Manager) createCert(ctx context.Context, ck certKey) (*tls.Certificate, error) { - // TODO: maybe rewrite this whole piece using sync.Once - state, err := m.certState(ck) - if err != nil { - return nil, err - } - // state may exist if another goroutine is already working on it - // in which case just wait for it to finish - if !state.locked { - state.RLock() - defer state.RUnlock() - return state.tlscert() - } - - // We are the first; state is locked. - // Unblock the readers when domain ownership is verified - // and we got the cert or the process failed. - defer state.Unlock() - state.locked = false - - der, leaf, err := m.authorizedCert(ctx, state.key, ck) - if err != nil { - // Remove the failed state after some time, - // making the manager call createCert again on the following TLS hello. - time.AfterFunc(createCertRetryAfter, func() { - defer testDidRemoveState(ck) - m.stateMu.Lock() - defer m.stateMu.Unlock() - // Verify the state hasn't changed and it's still invalid - // before deleting. - s, ok := m.state[ck] - if !ok { - return - } - if _, err := validCert(ck, s.cert, s.key); err == nil { - return - } - delete(m.state, ck) - }) - return nil, err - } - state.cert = der - state.leaf = leaf - go m.renew(ck, state.key, state.leaf.NotAfter) - return state.tlscert() -} - -// certState returns a new or existing certState. -// If a new certState is returned, state.exist is false and the state is locked. -// The returned error is non-nil only in the case where a new state could not be created. -func (m *Manager) certState(ck certKey) (*certState, error) { - m.stateMu.Lock() - defer m.stateMu.Unlock() - if m.state == nil { - m.state = make(map[certKey]*certState) - } - // existing state - if state, ok := m.state[ck]; ok { - return state, nil - } - - // new locked state - var ( - err error - key crypto.Signer - ) - if ck.isRSA { - key, err = rsa.GenerateKey(rand.Reader, 2048) - } else { - key, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - } - if err != nil { - return nil, err - } - - state := &certState{ - key: key, - locked: true, - } - state.Lock() // will be unlocked by m.certState caller - m.state[ck] = state - return state, nil -} - -// authorizedCert starts the domain ownership verification process and requests a new cert upon success. -// The key argument is the certificate private key. -func (m *Manager) authorizedCert(ctx context.Context, key crypto.Signer, ck certKey) (der [][]byte, leaf *x509.Certificate, err error) { - client, err := m.acmeClient(ctx) - if err != nil { - return nil, nil, err - } - - if err := m.verify(ctx, client, ck.domain); err != nil { - return nil, nil, err - } - csr, err := certRequest(key, ck.domain, m.ExtraExtensions) - if err != nil { - return nil, nil, err - } - der, _, err = client.CreateCert(ctx, csr, 0, true) - if err != nil { - return nil, nil, err - } - leaf, err = validCert(ck, der, key) - if err != nil { - return nil, nil, err - } - return der, leaf, nil -} - -// revokePendingAuthz revokes all authorizations idenfied by the elements of uri slice. -// It ignores revocation errors. -func (m *Manager) revokePendingAuthz(ctx context.Context, uri []string) { - client, err := m.acmeClient(ctx) - if err != nil { - return - } - for _, u := range uri { - client.RevokeAuthorization(ctx, u) - } -} - -// verify runs the identifier (domain) authorization flow -// using each applicable ACME challenge type. -func (m *Manager) verify(ctx context.Context, client *acme.Client, domain string) error { - // The list of challenge types we'll try to fulfill - // in this specific order. - challengeTypes := []string{"tls-alpn-01", "tls-sni-02", "tls-sni-01"} - m.tokensMu.RLock() - if m.tryHTTP01 { - challengeTypes = append(challengeTypes, "http-01") - } - m.tokensMu.RUnlock() - - // Keep track of pending authzs and revoke the ones that did not validate. - pendingAuthzs := make(map[string]bool) - defer func() { - var uri []string - for k, pending := range pendingAuthzs { - if pending { - uri = append(uri, k) - } - } - if len(uri) > 0 { - // Use "detached" background context. - // The revocations need not happen in the current verification flow. - go m.revokePendingAuthz(context.Background(), uri) - } - }() - - // errs accumulates challenge failure errors, printed if all fail - errs := make(map[*acme.Challenge]error) - var nextTyp int // challengeType index of the next challenge type to try - for { - // Start domain authorization and get the challenge. - authz, err := client.Authorize(ctx, domain) - if err != nil { - return err - } - // No point in accepting challenges if the authorization status - // is in a final state. - switch authz.Status { - case acme.StatusValid: - return nil // already authorized - case acme.StatusInvalid: - return fmt.Errorf("acme/autocert: invalid authorization %q", authz.URI) - } - - pendingAuthzs[authz.URI] = true - - // Pick the next preferred challenge. - var chal *acme.Challenge - for chal == nil && nextTyp < len(challengeTypes) { - chal = pickChallenge(challengeTypes[nextTyp], authz.Challenges) - nextTyp++ - } - if chal == nil { - errorMsg := fmt.Sprintf("acme/autocert: unable to authorize %q", domain) - for chal, err := range errs { - errorMsg += fmt.Sprintf("; challenge %q failed with error: %v", chal.Type, err) - } - return errors.New(errorMsg) - } - cleanup, err := m.fulfill(ctx, client, chal, domain) - if err != nil { - errs[chal] = err - continue - } - defer cleanup() - if _, err := client.Accept(ctx, chal); err != nil { - errs[chal] = err - continue - } - - // A challenge is fulfilled and accepted: wait for the CA to validate. - if _, err := client.WaitAuthorization(ctx, authz.URI); err != nil { - errs[chal] = err - continue - } - delete(pendingAuthzs, authz.URI) - return nil - } -} - -// fulfill provisions a response to the challenge chal. -// The cleanup is non-nil only if provisioning succeeded. -func (m *Manager) fulfill(ctx context.Context, client *acme.Client, chal *acme.Challenge, domain string) (cleanup func(), err error) { - switch chal.Type { - case "tls-alpn-01": - cert, err := client.TLSALPN01ChallengeCert(chal.Token, domain) - if err != nil { - return nil, err - } - m.putCertToken(ctx, domain, &cert) - return func() { go m.deleteCertToken(domain) }, nil - case "tls-sni-01": - cert, name, err := client.TLSSNI01ChallengeCert(chal.Token) - if err != nil { - return nil, err - } - m.putCertToken(ctx, name, &cert) - return func() { go m.deleteCertToken(name) }, nil - case "tls-sni-02": - cert, name, err := client.TLSSNI02ChallengeCert(chal.Token) - if err != nil { - return nil, err - } - m.putCertToken(ctx, name, &cert) - return func() { go m.deleteCertToken(name) }, nil - case "http-01": - resp, err := client.HTTP01ChallengeResponse(chal.Token) - if err != nil { - return nil, err - } - p := client.HTTP01ChallengePath(chal.Token) - m.putHTTPToken(ctx, p, resp) - return func() { go m.deleteHTTPToken(p) }, nil - } - return nil, fmt.Errorf("acme/autocert: unknown challenge type %q", chal.Type) -} - -func pickChallenge(typ string, chal []*acme.Challenge) *acme.Challenge { - for _, c := range chal { - if c.Type == typ { - return c - } - } - return nil -} - -// putCertToken stores the token certificate with the specified name -// in both m.certTokens map and m.Cache. -func (m *Manager) putCertToken(ctx context.Context, name string, cert *tls.Certificate) { - m.tokensMu.Lock() - defer m.tokensMu.Unlock() - if m.certTokens == nil { - m.certTokens = make(map[string]*tls.Certificate) - } - m.certTokens[name] = cert - m.cachePut(ctx, certKey{domain: name, isToken: true}, cert) -} - -// deleteCertToken removes the token certificate with the specified name -// from both m.certTokens map and m.Cache. -func (m *Manager) deleteCertToken(name string) { - m.tokensMu.Lock() - defer m.tokensMu.Unlock() - delete(m.certTokens, name) - if m.Cache != nil { - ck := certKey{domain: name, isToken: true} - m.Cache.Delete(context.Background(), ck.String()) - } -} - -// httpToken retrieves an existing http-01 token value from an in-memory map -// or the optional cache. -func (m *Manager) httpToken(ctx context.Context, tokenPath string) ([]byte, error) { - m.tokensMu.RLock() - defer m.tokensMu.RUnlock() - if v, ok := m.httpTokens[tokenPath]; ok { - return v, nil - } - if m.Cache == nil { - return nil, fmt.Errorf("acme/autocert: no token at %q", tokenPath) - } - return m.Cache.Get(ctx, httpTokenCacheKey(tokenPath)) -} - -// putHTTPToken stores an http-01 token value using tokenPath as key -// in both in-memory map and the optional Cache. -// -// It ignores any error returned from Cache.Put. -func (m *Manager) putHTTPToken(ctx context.Context, tokenPath, val string) { - m.tokensMu.Lock() - defer m.tokensMu.Unlock() - if m.httpTokens == nil { - m.httpTokens = make(map[string][]byte) - } - b := []byte(val) - m.httpTokens[tokenPath] = b - if m.Cache != nil { - m.Cache.Put(ctx, httpTokenCacheKey(tokenPath), b) - } -} - -// deleteHTTPToken removes an http-01 token value from both in-memory map -// and the optional Cache, ignoring any error returned from the latter. -// -// If m.Cache is non-nil, it blocks until Cache.Delete returns without a timeout. -func (m *Manager) deleteHTTPToken(tokenPath string) { - m.tokensMu.Lock() - defer m.tokensMu.Unlock() - delete(m.httpTokens, tokenPath) - if m.Cache != nil { - m.Cache.Delete(context.Background(), httpTokenCacheKey(tokenPath)) - } -} - -// httpTokenCacheKey returns a key at which an http-01 token value may be stored -// in the Manager's optional Cache. -func httpTokenCacheKey(tokenPath string) string { - return path.Base(tokenPath) + "+http-01" -} - -// renew starts a cert renewal timer loop, one per domain. -// -// The loop is scheduled in two cases: -// - a cert was fetched from cache for the first time (wasn't in m.state) -// - a new cert was created by m.createCert -// -// The key argument is a certificate private key. -// The exp argument is the cert expiration time (NotAfter). -func (m *Manager) renew(ck certKey, key crypto.Signer, exp time.Time) { - m.renewalMu.Lock() - defer m.renewalMu.Unlock() - if m.renewal[ck] != nil { - // another goroutine is already on it - return - } - if m.renewal == nil { - m.renewal = make(map[certKey]*domainRenewal) - } - dr := &domainRenewal{m: m, ck: ck, key: key} - m.renewal[ck] = dr - dr.start(exp) -} - -// stopRenew stops all currently running cert renewal timers. -// The timers are not restarted during the lifetime of the Manager. -func (m *Manager) stopRenew() { - m.renewalMu.Lock() - defer m.renewalMu.Unlock() - for name, dr := range m.renewal { - delete(m.renewal, name) - dr.stop() - } -} - -func (m *Manager) accountKey(ctx context.Context) (crypto.Signer, error) { - const keyName = "acme_account+key" - - // Previous versions of autocert stored the value under a different key. - const legacyKeyName = "acme_account.key" - - genKey := func() (*ecdsa.PrivateKey, error) { - return ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - } - - if m.Cache == nil { - return genKey() - } - - data, err := m.Cache.Get(ctx, keyName) - if err == ErrCacheMiss { - data, err = m.Cache.Get(ctx, legacyKeyName) - } - if err == ErrCacheMiss { - key, err := genKey() - if err != nil { - return nil, err - } - var buf bytes.Buffer - if err := encodeECDSAKey(&buf, key); err != nil { - return nil, err - } - if err := m.Cache.Put(ctx, keyName, buf.Bytes()); err != nil { - return nil, err - } - return key, nil - } - if err != nil { - return nil, err - } - - priv, _ := pem.Decode(data) - if priv == nil || !strings.Contains(priv.Type, "PRIVATE") { - return nil, errors.New("acme/autocert: invalid account key found in cache") - } - return parsePrivateKey(priv.Bytes) -} - -func (m *Manager) acmeClient(ctx context.Context) (*acme.Client, error) { - m.clientMu.Lock() - defer m.clientMu.Unlock() - if m.client != nil { - return m.client, nil - } - - client := m.Client - if client == nil { - client = &acme.Client{DirectoryURL: acme.LetsEncryptURL} - } - if client.Key == nil { - var err error - client.Key, err = m.accountKey(ctx) - if err != nil { - return nil, err - } - } - var contact []string - if m.Email != "" { - contact = []string{"mailto:" + m.Email} - } - a := &acme.Account{Contact: contact} - _, err := client.Register(ctx, a, m.Prompt) - if ae, ok := err.(*acme.Error); err == nil || ok && ae.StatusCode == http.StatusConflict { - // conflict indicates the key is already registered - m.client = client - err = nil - } - return m.client, err -} - -func (m *Manager) hostPolicy() HostPolicy { - if m.HostPolicy != nil { - return m.HostPolicy - } - return defaultHostPolicy -} - -func (m *Manager) renewBefore() time.Duration { - if m.RenewBefore > renewJitter { - return m.RenewBefore - } - return 720 * time.Hour // 30 days -} - -// certState is ready when its mutex is unlocked for reading. -type certState struct { - sync.RWMutex - locked bool // locked for read/write - key crypto.Signer // private key for cert - cert [][]byte // DER encoding - leaf *x509.Certificate // parsed cert[0]; always non-nil if cert != nil -} - -// tlscert creates a tls.Certificate from s.key and s.cert. -// Callers should wrap it in s.RLock() and s.RUnlock(). -func (s *certState) tlscert() (*tls.Certificate, error) { - if s.key == nil { - return nil, errors.New("acme/autocert: missing signer") - } - if len(s.cert) == 0 { - return nil, errors.New("acme/autocert: missing certificate") - } - return &tls.Certificate{ - PrivateKey: s.key, - Certificate: s.cert, - Leaf: s.leaf, - }, nil -} - -// certRequest generates a CSR for the given common name cn and optional SANs. -func certRequest(key crypto.Signer, cn string, ext []pkix.Extension, san ...string) ([]byte, error) { - req := &x509.CertificateRequest{ - Subject: pkix.Name{CommonName: cn}, - DNSNames: san, - ExtraExtensions: ext, - } - return x509.CreateCertificateRequest(rand.Reader, req, key) -} - -// Attempt to parse the given private key DER block. OpenSSL 0.9.8 generates -// PKCS#1 private keys by default, while OpenSSL 1.0.0 generates PKCS#8 keys. -// OpenSSL ecparam generates SEC1 EC private keys for ECDSA. We try all three. -// -// Inspired by parsePrivateKey in crypto/tls/tls.go. -func parsePrivateKey(der []byte) (crypto.Signer, error) { - if key, err := x509.ParsePKCS1PrivateKey(der); err == nil { - return key, nil - } - if key, err := x509.ParsePKCS8PrivateKey(der); err == nil { - switch key := key.(type) { - case *rsa.PrivateKey: - return key, nil - case *ecdsa.PrivateKey: - return key, nil - default: - return nil, errors.New("acme/autocert: unknown private key type in PKCS#8 wrapping") - } - } - if key, err := x509.ParseECPrivateKey(der); err == nil { - return key, nil - } - - return nil, errors.New("acme/autocert: failed to parse private key") -} - -// validCert parses a cert chain provided as der argument and verifies the leaf and der[0] -// correspond to the private key, the domain and key type match, and expiration dates -// are valid. It doesn't do any revocation checking. -// -// The returned value is the verified leaf cert. -func validCert(ck certKey, der [][]byte, key crypto.Signer) (leaf *x509.Certificate, err error) { - // parse public part(s) - var n int - for _, b := range der { - n += len(b) - } - pub := make([]byte, n) - n = 0 - for _, b := range der { - n += copy(pub[n:], b) - } - x509Cert, err := x509.ParseCertificates(pub) - if err != nil || len(x509Cert) == 0 { - return nil, errors.New("acme/autocert: no public key found") - } - // verify the leaf is not expired and matches the domain name - leaf = x509Cert[0] - now := timeNow() - if now.Before(leaf.NotBefore) { - return nil, errors.New("acme/autocert: certificate is not valid yet") - } - if now.After(leaf.NotAfter) { - return nil, errors.New("acme/autocert: expired certificate") - } - if err := leaf.VerifyHostname(ck.domain); err != nil { - return nil, err - } - // ensure the leaf corresponds to the private key and matches the certKey type - switch pub := leaf.PublicKey.(type) { - case *rsa.PublicKey: - prv, ok := key.(*rsa.PrivateKey) - if !ok { - return nil, errors.New("acme/autocert: private key type does not match public key type") - } - if pub.N.Cmp(prv.N) != 0 { - return nil, errors.New("acme/autocert: private key does not match public key") - } - if !ck.isRSA && !ck.isToken { - return nil, errors.New("acme/autocert: key type does not match expected value") - } - case *ecdsa.PublicKey: - prv, ok := key.(*ecdsa.PrivateKey) - if !ok { - return nil, errors.New("acme/autocert: private key type does not match public key type") - } - if pub.X.Cmp(prv.X) != 0 || pub.Y.Cmp(prv.Y) != 0 { - return nil, errors.New("acme/autocert: private key does not match public key") - } - if ck.isRSA && !ck.isToken { - return nil, errors.New("acme/autocert: key type does not match expected value") - } - default: - return nil, errors.New("acme/autocert: unknown public key algorithm") - } - return leaf, nil -} - -type lockedMathRand struct { - sync.Mutex - rnd *mathrand.Rand -} - -func (r *lockedMathRand) int63n(max int64) int64 { - r.Lock() - n := r.rnd.Int63n(max) - r.Unlock() - return n -} - -// For easier testing. -var ( - timeNow = time.Now - - // Called when a state is removed. - testDidRemoveState = func(certKey) {} -) diff --git a/vendor/golang.org/x/crypto/acme/autocert/cache.go b/vendor/golang.org/x/crypto/acme/autocert/cache.go deleted file mode 100644 index aa9aa845c8..0000000000 --- a/vendor/golang.org/x/crypto/acme/autocert/cache.go +++ /dev/null @@ -1,130 +0,0 @@ -// Copyright 2016 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package autocert - -import ( - "context" - "errors" - "io/ioutil" - "os" - "path/filepath" -) - -// ErrCacheMiss is returned when a certificate is not found in cache. -var ErrCacheMiss = errors.New("acme/autocert: certificate cache miss") - -// Cache is used by Manager to store and retrieve previously obtained certificates -// and other account data as opaque blobs. -// -// Cache implementations should not rely on the key naming pattern. Keys can -// include any printable ASCII characters, except the following: \/:*?"<>| -type Cache interface { - // Get returns a certificate data for the specified key. - // If there's no such key, Get returns ErrCacheMiss. - Get(ctx context.Context, key string) ([]byte, error) - - // Put stores the data in the cache under the specified key. - // Underlying implementations may use any data storage format, - // as long as the reverse operation, Get, results in the original data. - Put(ctx context.Context, key string, data []byte) error - - // Delete removes a certificate data from the cache under the specified key. - // If there's no such key in the cache, Delete returns nil. - Delete(ctx context.Context, key string) error -} - -// DirCache implements Cache using a directory on the local filesystem. -// If the directory does not exist, it will be created with 0700 permissions. -type DirCache string - -// Get reads a certificate data from the specified file name. -func (d DirCache) Get(ctx context.Context, name string) ([]byte, error) { - name = filepath.Join(string(d), name) - var ( - data []byte - err error - done = make(chan struct{}) - ) - go func() { - data, err = ioutil.ReadFile(name) - close(done) - }() - select { - case <-ctx.Done(): - return nil, ctx.Err() - case <-done: - } - if os.IsNotExist(err) { - return nil, ErrCacheMiss - } - return data, err -} - -// Put writes the certificate data to the specified file name. -// The file will be created with 0600 permissions. -func (d DirCache) Put(ctx context.Context, name string, data []byte) error { - if err := os.MkdirAll(string(d), 0700); err != nil { - return err - } - - done := make(chan struct{}) - var err error - go func() { - defer close(done) - var tmp string - if tmp, err = d.writeTempFile(name, data); err != nil { - return - } - select { - case <-ctx.Done(): - // Don't overwrite the file if the context was canceled. - default: - newName := filepath.Join(string(d), name) - err = os.Rename(tmp, newName) - } - }() - select { - case <-ctx.Done(): - return ctx.Err() - case <-done: - } - return err -} - -// Delete removes the specified file name. -func (d DirCache) Delete(ctx context.Context, name string) error { - name = filepath.Join(string(d), name) - var ( - err error - done = make(chan struct{}) - ) - go func() { - err = os.Remove(name) - close(done) - }() - select { - case <-ctx.Done(): - return ctx.Err() - case <-done: - } - if err != nil && !os.IsNotExist(err) { - return err - } - return nil -} - -// writeTempFile writes b to a temporary file, closes the file and returns its path. -func (d DirCache) writeTempFile(prefix string, b []byte) (string, error) { - // TempFile uses 0600 permissions - f, err := ioutil.TempFile(string(d), prefix) - if err != nil { - return "", err - } - if _, err := f.Write(b); err != nil { - f.Close() - return "", err - } - return f.Name(), f.Close() -} diff --git a/vendor/golang.org/x/crypto/acme/autocert/listener.go b/vendor/golang.org/x/crypto/acme/autocert/listener.go deleted file mode 100644 index 1e069818a5..0000000000 --- a/vendor/golang.org/x/crypto/acme/autocert/listener.go +++ /dev/null @@ -1,157 +0,0 @@ -// Copyright 2017 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package autocert - -import ( - "crypto/tls" - "log" - "net" - "os" - "path/filepath" - "runtime" - "time" -) - -// NewListener returns a net.Listener that listens on the standard TLS -// port (443) on all interfaces and returns *tls.Conn connections with -// LetsEncrypt certificates for the provided domain or domains. -// -// It enables one-line HTTPS servers: -// -// log.Fatal(http.Serve(autocert.NewListener("example.com"), handler)) -// -// NewListener is a convenience function for a common configuration. -// More complex or custom configurations can use the autocert.Manager -// type instead. -// -// Use of this function implies acceptance of the LetsEncrypt Terms of -// Service. If domains is not empty, the provided domains are passed -// to HostWhitelist. If domains is empty, the listener will do -// LetsEncrypt challenges for any requested domain, which is not -// recommended. -// -// Certificates are cached in a "golang-autocert" directory under an -// operating system-specific cache or temp directory. This may not -// be suitable for servers spanning multiple machines. -// -// The returned listener uses a *tls.Config that enables HTTP/2, and -// should only be used with servers that support HTTP/2. -// -// The returned Listener also enables TCP keep-alives on the accepted -// connections. The returned *tls.Conn are returned before their TLS -// handshake has completed. -func NewListener(domains ...string) net.Listener { - m := &Manager{ - Prompt: AcceptTOS, - } - if len(domains) > 0 { - m.HostPolicy = HostWhitelist(domains...) - } - dir := cacheDir() - if err := os.MkdirAll(dir, 0700); err != nil { - log.Printf("warning: autocert.NewListener not using a cache: %v", err) - } else { - m.Cache = DirCache(dir) - } - return m.Listener() -} - -// Listener listens on the standard TLS port (443) on all interfaces -// and returns a net.Listener returning *tls.Conn connections. -// -// The returned listener uses a *tls.Config that enables HTTP/2, and -// should only be used with servers that support HTTP/2. -// -// The returned Listener also enables TCP keep-alives on the accepted -// connections. The returned *tls.Conn are returned before their TLS -// handshake has completed. -// -// Unlike NewListener, it is the caller's responsibility to initialize -// the Manager m's Prompt, Cache, HostPolicy, and other desired options. -func (m *Manager) Listener() net.Listener { - ln := &listener{ - m: m, - conf: m.TLSConfig(), - } - ln.tcpListener, ln.tcpListenErr = net.Listen("tcp", ":443") - return ln -} - -type listener struct { - m *Manager - conf *tls.Config - - tcpListener net.Listener - tcpListenErr error -} - -func (ln *listener) Accept() (net.Conn, error) { - if ln.tcpListenErr != nil { - return nil, ln.tcpListenErr - } - conn, err := ln.tcpListener.Accept() - if err != nil { - return nil, err - } - tcpConn := conn.(*net.TCPConn) - - // Because Listener is a convenience function, help out with - // this too. This is not possible for the caller to set once - // we return a *tcp.Conn wrapping an inaccessible net.Conn. - // If callers don't want this, they can do things the manual - // way and tweak as needed. But this is what net/http does - // itself, so copy that. If net/http changes, we can change - // here too. - tcpConn.SetKeepAlive(true) - tcpConn.SetKeepAlivePeriod(3 * time.Minute) - - return tls.Server(tcpConn, ln.conf), nil -} - -func (ln *listener) Addr() net.Addr { - if ln.tcpListener != nil { - return ln.tcpListener.Addr() - } - // net.Listen failed. Return something non-nil in case callers - // call Addr before Accept: - return &net.TCPAddr{IP: net.IP{0, 0, 0, 0}, Port: 443} -} - -func (ln *listener) Close() error { - if ln.tcpListenErr != nil { - return ln.tcpListenErr - } - return ln.tcpListener.Close() -} - -func homeDir() string { - if runtime.GOOS == "windows" { - return os.Getenv("HOMEDRIVE") + os.Getenv("HOMEPATH") - } - if h := os.Getenv("HOME"); h != "" { - return h - } - return "/" -} - -func cacheDir() string { - const base = "golang-autocert" - switch runtime.GOOS { - case "darwin": - return filepath.Join(homeDir(), "Library", "Caches", base) - case "windows": - for _, ev := range []string{"APPDATA", "CSIDL_APPDATA", "TEMP", "TMP"} { - if v := os.Getenv(ev); v != "" { - return filepath.Join(v, base) - } - } - // Worst case: - return filepath.Join(homeDir(), base) - } - if xdg := os.Getenv("XDG_CACHE_HOME"); xdg != "" { - return filepath.Join(xdg, base) - } - return filepath.Join(homeDir(), ".cache", base) -} diff --git a/vendor/golang.org/x/crypto/acme/autocert/renewal.go b/vendor/golang.org/x/crypto/acme/autocert/renewal.go deleted file mode 100644 index ef3e44e199..0000000000 --- a/vendor/golang.org/x/crypto/acme/autocert/renewal.go +++ /dev/null @@ -1,141 +0,0 @@ -// Copyright 2016 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package autocert - -import ( - "context" - "crypto" - "sync" - "time" -) - -// renewJitter is the maximum deviation from Manager.RenewBefore. -const renewJitter = time.Hour - -// domainRenewal tracks the state used by the periodic timers -// renewing a single domain's cert. -type domainRenewal struct { - m *Manager - ck certKey - key crypto.Signer - - timerMu sync.Mutex - timer *time.Timer -} - -// start starts a cert renewal timer at the time -// defined by the certificate expiration time exp. -// -// If the timer is already started, calling start is a noop. -func (dr *domainRenewal) start(exp time.Time) { - dr.timerMu.Lock() - defer dr.timerMu.Unlock() - if dr.timer != nil { - return - } - dr.timer = time.AfterFunc(dr.next(exp), dr.renew) -} - -// stop stops the cert renewal timer. -// If the timer is already stopped, calling stop is a noop. -func (dr *domainRenewal) stop() { - dr.timerMu.Lock() - defer dr.timerMu.Unlock() - if dr.timer == nil { - return - } - dr.timer.Stop() - dr.timer = nil -} - -// renew is called periodically by a timer. -// The first renew call is kicked off by dr.start. -func (dr *domainRenewal) renew() { - dr.timerMu.Lock() - defer dr.timerMu.Unlock() - if dr.timer == nil { - return - } - - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute) - defer cancel() - // TODO: rotate dr.key at some point? - next, err := dr.do(ctx) - if err != nil { - next = renewJitter / 2 - next += time.Duration(pseudoRand.int63n(int64(next))) - } - dr.timer = time.AfterFunc(next, dr.renew) - testDidRenewLoop(next, err) -} - -// updateState locks and replaces the relevant Manager.state item with the given -// state. It additionally updates dr.key with the given state's key. -func (dr *domainRenewal) updateState(state *certState) { - dr.m.stateMu.Lock() - defer dr.m.stateMu.Unlock() - dr.key = state.key - dr.m.state[dr.ck] = state -} - -// do is similar to Manager.createCert but it doesn't lock a Manager.state item. -// Instead, it requests a new certificate independently and, upon success, -// replaces dr.m.state item with a new one and updates cache for the given domain. -// -// It may lock and update the Manager.state if the expiration date of the currently -// cached cert is far enough in the future. -// -// The returned value is a time interval after which the renewal should occur again. -func (dr *domainRenewal) do(ctx context.Context) (time.Duration, error) { - // a race is likely unavoidable in a distributed environment - // but we try nonetheless - if tlscert, err := dr.m.cacheGet(ctx, dr.ck); err == nil { - next := dr.next(tlscert.Leaf.NotAfter) - if next > dr.m.renewBefore()+renewJitter { - signer, ok := tlscert.PrivateKey.(crypto.Signer) - if ok { - state := &certState{ - key: signer, - cert: tlscert.Certificate, - leaf: tlscert.Leaf, - } - dr.updateState(state) - return next, nil - } - } - } - - der, leaf, err := dr.m.authorizedCert(ctx, dr.key, dr.ck) - if err != nil { - return 0, err - } - state := &certState{ - key: dr.key, - cert: der, - leaf: leaf, - } - tlscert, err := state.tlscert() - if err != nil { - return 0, err - } - if err := dr.m.cachePut(ctx, dr.ck, tlscert); err != nil { - return 0, err - } - dr.updateState(state) - return dr.next(leaf.NotAfter), nil -} - -func (dr *domainRenewal) next(expiry time.Time) time.Duration { - d := expiry.Sub(timeNow()) - dr.m.renewBefore() - // add a bit of randomness to renew deadline - n := pseudoRand.int63n(int64(renewJitter)) - d -= time.Duration(n) - if d < 0 { - return 0 - } - return d -} - -var testDidRenewLoop = func(next time.Duration, err error) {} diff --git a/vendor/golang.org/x/crypto/acme/http.go b/vendor/golang.org/x/crypto/acme/http.go deleted file mode 100644 index a43ce6a5fe..0000000000 --- a/vendor/golang.org/x/crypto/acme/http.go +++ /dev/null @@ -1,281 +0,0 @@ -// Copyright 2018 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package acme - -import ( - "bytes" - "context" - "crypto" - "crypto/rand" - "encoding/json" - "fmt" - "io/ioutil" - "math/big" - "net/http" - "strconv" - "strings" - "time" -) - -// retryTimer encapsulates common logic for retrying unsuccessful requests. -// It is not safe for concurrent use. -type retryTimer struct { - // backoffFn provides backoff delay sequence for retries. - // See Client.RetryBackoff doc comment. - backoffFn func(n int, r *http.Request, res *http.Response) time.Duration - // n is the current retry attempt. - n int -} - -func (t *retryTimer) inc() { - t.n++ -} - -// backoff pauses the current goroutine as described in Client.RetryBackoff. -func (t *retryTimer) backoff(ctx context.Context, r *http.Request, res *http.Response) error { - d := t.backoffFn(t.n, r, res) - if d <= 0 { - return fmt.Errorf("acme: no more retries for %s; tried %d time(s)", r.URL, t.n) - } - wakeup := time.NewTimer(d) - defer wakeup.Stop() - select { - case <-ctx.Done(): - return ctx.Err() - case <-wakeup.C: - return nil - } -} - -func (c *Client) retryTimer() *retryTimer { - f := c.RetryBackoff - if f == nil { - f = defaultBackoff - } - return &retryTimer{backoffFn: f} -} - -// defaultBackoff provides default Client.RetryBackoff implementation -// using a truncated exponential backoff algorithm, -// as described in Client.RetryBackoff. -// -// The n argument is always bounded between 1 and 30. -// The returned value is always greater than 0. -func defaultBackoff(n int, r *http.Request, res *http.Response) time.Duration { - const max = 10 * time.Second - var jitter time.Duration - if x, err := rand.Int(rand.Reader, big.NewInt(1000)); err == nil { - // Set the minimum to 1ms to avoid a case where - // an invalid Retry-After value is parsed into 0 below, - // resulting in the 0 returned value which would unintentionally - // stop the retries. - jitter = (1 + time.Duration(x.Int64())) * time.Millisecond - } - if v, ok := res.Header["Retry-After"]; ok { - return retryAfter(v[0]) + jitter - } - - if n < 1 { - n = 1 - } - if n > 30 { - n = 30 - } - d := time.Duration(1< max { - return max - } - return d -} - -// retryAfter parses a Retry-After HTTP header value, -// trying to convert v into an int (seconds) or use http.ParseTime otherwise. -// It returns zero value if v cannot be parsed. -func retryAfter(v string) time.Duration { - if i, err := strconv.Atoi(v); err == nil { - return time.Duration(i) * time.Second - } - t, err := http.ParseTime(v) - if err != nil { - return 0 - } - return t.Sub(timeNow()) -} - -// resOkay is a function that reports whether the provided response is okay. -// It is expected to keep the response body unread. -type resOkay func(*http.Response) bool - -// wantStatus returns a function which reports whether the code -// matches the status code of a response. -func wantStatus(codes ...int) resOkay { - return func(res *http.Response) bool { - for _, code := range codes { - if code == res.StatusCode { - return true - } - } - return false - } -} - -// get issues an unsigned GET request to the specified URL. -// It returns a non-error value only when ok reports true. -// -// get retries unsuccessful attempts according to c.RetryBackoff -// until the context is done or a non-retriable error is received. -func (c *Client) get(ctx context.Context, url string, ok resOkay) (*http.Response, error) { - retry := c.retryTimer() - for { - req, err := http.NewRequest("GET", url, nil) - if err != nil { - return nil, err - } - res, err := c.doNoRetry(ctx, req) - switch { - case err != nil: - return nil, err - case ok(res): - return res, nil - case isRetriable(res.StatusCode): - retry.inc() - resErr := responseError(res) - res.Body.Close() - // Ignore the error value from retry.backoff - // and return the one from last retry, as received from the CA. - if retry.backoff(ctx, req, res) != nil { - return nil, resErr - } - default: - defer res.Body.Close() - return nil, responseError(res) - } - } -} - -// post issues a signed POST request in JWS format using the provided key -// to the specified URL. -// It returns a non-error value only when ok reports true. -// -// post retries unsuccessful attempts according to c.RetryBackoff -// until the context is done or a non-retriable error is received. -// It uses postNoRetry to make individual requests. -func (c *Client) post(ctx context.Context, key crypto.Signer, url string, body interface{}, ok resOkay) (*http.Response, error) { - retry := c.retryTimer() - for { - res, req, err := c.postNoRetry(ctx, key, url, body) - if err != nil { - return nil, err - } - if ok(res) { - return res, nil - } - resErr := responseError(res) - res.Body.Close() - switch { - // Check for bad nonce before isRetriable because it may have been returned - // with an unretriable response code such as 400 Bad Request. - case isBadNonce(resErr): - // Consider any previously stored nonce values to be invalid. - c.clearNonces() - case !isRetriable(res.StatusCode): - return nil, resErr - } - retry.inc() - // Ignore the error value from retry.backoff - // and return the one from last retry, as received from the CA. - if err := retry.backoff(ctx, req, res); err != nil { - return nil, resErr - } - } -} - -// postNoRetry signs the body with the given key and POSTs it to the provided url. -// The body argument must be JSON-serializable. -// It is used by c.post to retry unsuccessful attempts. -func (c *Client) postNoRetry(ctx context.Context, key crypto.Signer, url string, body interface{}) (*http.Response, *http.Request, error) { - nonce, err := c.popNonce(ctx, url) - if err != nil { - return nil, nil, err - } - b, err := jwsEncodeJSON(body, key, nonce) - if err != nil { - return nil, nil, err - } - req, err := http.NewRequest("POST", url, bytes.NewReader(b)) - if err != nil { - return nil, nil, err - } - req.Header.Set("Content-Type", "application/jose+json") - res, err := c.doNoRetry(ctx, req) - if err != nil { - return nil, nil, err - } - c.addNonce(res.Header) - return res, req, nil -} - -// doNoRetry issues a request req, replacing its context (if any) with ctx. -func (c *Client) doNoRetry(ctx context.Context, req *http.Request) (*http.Response, error) { - res, err := c.httpClient().Do(req.WithContext(ctx)) - if err != nil { - select { - case <-ctx.Done(): - // Prefer the unadorned context error. - // (The acme package had tests assuming this, previously from ctxhttp's - // behavior, predating net/http supporting contexts natively) - // TODO(bradfitz): reconsider this in the future. But for now this - // requires no test updates. - return nil, ctx.Err() - default: - return nil, err - } - } - return res, nil -} - -func (c *Client) httpClient() *http.Client { - if c.HTTPClient != nil { - return c.HTTPClient - } - return http.DefaultClient -} - -// isBadNonce reports whether err is an ACME "badnonce" error. -func isBadNonce(err error) bool { - // According to the spec badNonce is urn:ietf:params:acme:error:badNonce. - // However, ACME servers in the wild return their versions of the error. - // See https://tools.ietf.org/html/draft-ietf-acme-acme-02#section-5.4 - // and https://github.com/letsencrypt/boulder/blob/0e07eacb/docs/acme-divergences.md#section-66. - ae, ok := err.(*Error) - return ok && strings.HasSuffix(strings.ToLower(ae.ProblemType), ":badnonce") -} - -// isRetriable reports whether a request can be retried -// based on the response status code. -// -// Note that a "bad nonce" error is returned with a non-retriable 400 Bad Request code. -// Callers should parse the response and check with isBadNonce. -func isRetriable(code int) bool { - return code <= 399 || code >= 500 || code == http.StatusTooManyRequests -} - -// responseError creates an error of Error type from resp. -func responseError(resp *http.Response) error { - // don't care if ReadAll returns an error: - // json.Unmarshal will fail in that case anyway - b, _ := ioutil.ReadAll(resp.Body) - e := &wireError{Status: resp.StatusCode} - if err := json.Unmarshal(b, e); err != nil { - // this is not a regular error response: - // populate detail with anything we received, - // e.Status will already contain HTTP response code value - e.Detail = string(b) - if e.Detail == "" { - e.Detail = resp.Status - } - } - return e.error(resp.Header) -} diff --git a/vendor/golang.org/x/crypto/acme/jws.go b/vendor/golang.org/x/crypto/acme/jws.go deleted file mode 100644 index 6cbca25de9..0000000000 --- a/vendor/golang.org/x/crypto/acme/jws.go +++ /dev/null @@ -1,153 +0,0 @@ -// Copyright 2015 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package acme - -import ( - "crypto" - "crypto/ecdsa" - "crypto/rand" - "crypto/rsa" - "crypto/sha256" - _ "crypto/sha512" // need for EC keys - "encoding/base64" - "encoding/json" - "fmt" - "math/big" -) - -// jwsEncodeJSON signs claimset using provided key and a nonce. -// The result is serialized in JSON format. -// See https://tools.ietf.org/html/rfc7515#section-7. -func jwsEncodeJSON(claimset interface{}, key crypto.Signer, nonce string) ([]byte, error) { - jwk, err := jwkEncode(key.Public()) - if err != nil { - return nil, err - } - alg, sha := jwsHasher(key) - if alg == "" || !sha.Available() { - return nil, ErrUnsupportedKey - } - phead := fmt.Sprintf(`{"alg":%q,"jwk":%s,"nonce":%q}`, alg, jwk, nonce) - phead = base64.RawURLEncoding.EncodeToString([]byte(phead)) - cs, err := json.Marshal(claimset) - if err != nil { - return nil, err - } - payload := base64.RawURLEncoding.EncodeToString(cs) - hash := sha.New() - hash.Write([]byte(phead + "." + payload)) - sig, err := jwsSign(key, sha, hash.Sum(nil)) - if err != nil { - return nil, err - } - - enc := struct { - Protected string `json:"protected"` - Payload string `json:"payload"` - Sig string `json:"signature"` - }{ - Protected: phead, - Payload: payload, - Sig: base64.RawURLEncoding.EncodeToString(sig), - } - return json.Marshal(&enc) -} - -// jwkEncode encodes public part of an RSA or ECDSA key into a JWK. -// The result is also suitable for creating a JWK thumbprint. -// https://tools.ietf.org/html/rfc7517 -func jwkEncode(pub crypto.PublicKey) (string, error) { - switch pub := pub.(type) { - case *rsa.PublicKey: - // https://tools.ietf.org/html/rfc7518#section-6.3.1 - n := pub.N - e := big.NewInt(int64(pub.E)) - // Field order is important. - // See https://tools.ietf.org/html/rfc7638#section-3.3 for details. - return fmt.Sprintf(`{"e":"%s","kty":"RSA","n":"%s"}`, - base64.RawURLEncoding.EncodeToString(e.Bytes()), - base64.RawURLEncoding.EncodeToString(n.Bytes()), - ), nil - case *ecdsa.PublicKey: - // https://tools.ietf.org/html/rfc7518#section-6.2.1 - p := pub.Curve.Params() - n := p.BitSize / 8 - if p.BitSize%8 != 0 { - n++ - } - x := pub.X.Bytes() - if n > len(x) { - x = append(make([]byte, n-len(x)), x...) - } - y := pub.Y.Bytes() - if n > len(y) { - y = append(make([]byte, n-len(y)), y...) - } - // Field order is important. - // See https://tools.ietf.org/html/rfc7638#section-3.3 for details. - return fmt.Sprintf(`{"crv":"%s","kty":"EC","x":"%s","y":"%s"}`, - p.Name, - base64.RawURLEncoding.EncodeToString(x), - base64.RawURLEncoding.EncodeToString(y), - ), nil - } - return "", ErrUnsupportedKey -} - -// jwsSign signs the digest using the given key. -// It returns ErrUnsupportedKey if the key type is unknown. -// The hash is used only for RSA keys. -func jwsSign(key crypto.Signer, hash crypto.Hash, digest []byte) ([]byte, error) { - switch key := key.(type) { - case *rsa.PrivateKey: - return key.Sign(rand.Reader, digest, hash) - case *ecdsa.PrivateKey: - r, s, err := ecdsa.Sign(rand.Reader, key, digest) - if err != nil { - return nil, err - } - rb, sb := r.Bytes(), s.Bytes() - size := key.Params().BitSize / 8 - if size%8 > 0 { - size++ - } - sig := make([]byte, size*2) - copy(sig[size-len(rb):], rb) - copy(sig[size*2-len(sb):], sb) - return sig, nil - } - return nil, ErrUnsupportedKey -} - -// jwsHasher indicates suitable JWS algorithm name and a hash function -// to use for signing a digest with the provided key. -// It returns ("", 0) if the key is not supported. -func jwsHasher(key crypto.Signer) (string, crypto.Hash) { - switch key := key.(type) { - case *rsa.PrivateKey: - return "RS256", crypto.SHA256 - case *ecdsa.PrivateKey: - switch key.Params().Name { - case "P-256": - return "ES256", crypto.SHA256 - case "P-384": - return "ES384", crypto.SHA384 - case "P-521": - return "ES512", crypto.SHA512 - } - } - return "", 0 -} - -// JWKThumbprint creates a JWK thumbprint out of pub -// as specified in https://tools.ietf.org/html/rfc7638. -func JWKThumbprint(pub crypto.PublicKey) (string, error) { - jwk, err := jwkEncode(pub) - if err != nil { - return "", err - } - b := sha256.Sum256([]byte(jwk)) - return base64.RawURLEncoding.EncodeToString(b[:]), nil -} diff --git a/vendor/golang.org/x/crypto/acme/types.go b/vendor/golang.org/x/crypto/acme/types.go deleted file mode 100644 index 54792c0650..0000000000 --- a/vendor/golang.org/x/crypto/acme/types.go +++ /dev/null @@ -1,329 +0,0 @@ -// Copyright 2016 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package acme - -import ( - "crypto" - "crypto/x509" - "errors" - "fmt" - "net/http" - "strings" - "time" -) - -// ACME server response statuses used to describe Authorization and Challenge states. -const ( - StatusUnknown = "unknown" - StatusPending = "pending" - StatusProcessing = "processing" - StatusValid = "valid" - StatusInvalid = "invalid" - StatusRevoked = "revoked" -) - -// CRLReasonCode identifies the reason for a certificate revocation. -type CRLReasonCode int - -// CRL reason codes as defined in RFC 5280. -const ( - CRLReasonUnspecified CRLReasonCode = 0 - CRLReasonKeyCompromise CRLReasonCode = 1 - CRLReasonCACompromise CRLReasonCode = 2 - CRLReasonAffiliationChanged CRLReasonCode = 3 - CRLReasonSuperseded CRLReasonCode = 4 - CRLReasonCessationOfOperation CRLReasonCode = 5 - CRLReasonCertificateHold CRLReasonCode = 6 - CRLReasonRemoveFromCRL CRLReasonCode = 8 - CRLReasonPrivilegeWithdrawn CRLReasonCode = 9 - CRLReasonAACompromise CRLReasonCode = 10 -) - -// ErrUnsupportedKey is returned when an unsupported key type is encountered. -var ErrUnsupportedKey = errors.New("acme: unknown key type; only RSA and ECDSA are supported") - -// Error is an ACME error, defined in Problem Details for HTTP APIs doc -// http://tools.ietf.org/html/draft-ietf-appsawg-http-problem. -type Error struct { - // StatusCode is The HTTP status code generated by the origin server. - StatusCode int - // ProblemType is a URI reference that identifies the problem type, - // typically in a "urn:acme:error:xxx" form. - ProblemType string - // Detail is a human-readable explanation specific to this occurrence of the problem. - Detail string - // Header is the original server error response headers. - // It may be nil. - Header http.Header -} - -func (e *Error) Error() string { - return fmt.Sprintf("%d %s: %s", e.StatusCode, e.ProblemType, e.Detail) -} - -// AuthorizationError indicates that an authorization for an identifier -// did not succeed. -// It contains all errors from Challenge items of the failed Authorization. -type AuthorizationError struct { - // URI uniquely identifies the failed Authorization. - URI string - - // Identifier is an AuthzID.Value of the failed Authorization. - Identifier string - - // Errors is a collection of non-nil error values of Challenge items - // of the failed Authorization. - Errors []error -} - -func (a *AuthorizationError) Error() string { - e := make([]string, len(a.Errors)) - for i, err := range a.Errors { - e[i] = err.Error() - } - return fmt.Sprintf("acme: authorization error for %s: %s", a.Identifier, strings.Join(e, "; ")) -} - -// RateLimit reports whether err represents a rate limit error and -// any Retry-After duration returned by the server. -// -// See the following for more details on rate limiting: -// https://tools.ietf.org/html/draft-ietf-acme-acme-05#section-5.6 -func RateLimit(err error) (time.Duration, bool) { - e, ok := err.(*Error) - if !ok { - return 0, false - } - // Some CA implementations may return incorrect values. - // Use case-insensitive comparison. - if !strings.HasSuffix(strings.ToLower(e.ProblemType), ":ratelimited") { - return 0, false - } - if e.Header == nil { - return 0, true - } - return retryAfter(e.Header.Get("Retry-After")), true -} - -// Account is a user account. It is associated with a private key. -type Account struct { - // URI is the account unique ID, which is also a URL used to retrieve - // account data from the CA. - URI string - - // Contact is a slice of contact info used during registration. - Contact []string - - // The terms user has agreed to. - // A value not matching CurrentTerms indicates that the user hasn't agreed - // to the actual Terms of Service of the CA. - AgreedTerms string - - // Actual terms of a CA. - CurrentTerms string - - // Authz is the authorization URL used to initiate a new authz flow. - Authz string - - // Authorizations is a URI from which a list of authorizations - // granted to this account can be fetched via a GET request. - Authorizations string - - // Certificates is a URI from which a list of certificates - // issued for this account can be fetched via a GET request. - Certificates string -} - -// Directory is ACME server discovery data. -type Directory struct { - // RegURL is an account endpoint URL, allowing for creating new - // and modifying existing accounts. - RegURL string - - // AuthzURL is used to initiate Identifier Authorization flow. - AuthzURL string - - // CertURL is a new certificate issuance endpoint URL. - CertURL string - - // RevokeURL is used to initiate a certificate revocation flow. - RevokeURL string - - // Term is a URI identifying the current terms of service. - Terms string - - // Website is an HTTP or HTTPS URL locating a website - // providing more information about the ACME server. - Website string - - // CAA consists of lowercase hostname elements, which the ACME server - // recognises as referring to itself for the purposes of CAA record validation - // as defined in RFC6844. - CAA []string -} - -// Challenge encodes a returned CA challenge. -// Its Error field may be non-nil if the challenge is part of an Authorization -// with StatusInvalid. -type Challenge struct { - // Type is the challenge type, e.g. "http-01", "tls-sni-02", "dns-01". - Type string - - // URI is where a challenge response can be posted to. - URI string - - // Token is a random value that uniquely identifies the challenge. - Token string - - // Status identifies the status of this challenge. - Status string - - // Error indicates the reason for an authorization failure - // when this challenge was used. - // The type of a non-nil value is *Error. - Error error -} - -// Authorization encodes an authorization response. -type Authorization struct { - // URI uniquely identifies a authorization. - URI string - - // Status identifies the status of an authorization. - Status string - - // Identifier is what the account is authorized to represent. - Identifier AuthzID - - // Challenges that the client needs to fulfill in order to prove possession - // of the identifier (for pending authorizations). - // For final authorizations, the challenges that were used. - Challenges []*Challenge - - // A collection of sets of challenges, each of which would be sufficient - // to prove possession of the identifier. - // Clients must complete a set of challenges that covers at least one set. - // Challenges are identified by their indices in the challenges array. - // If this field is empty, the client needs to complete all challenges. - Combinations [][]int -} - -// AuthzID is an identifier that an account is authorized to represent. -type AuthzID struct { - Type string // The type of identifier, e.g. "dns". - Value string // The identifier itself, e.g. "example.org". -} - -// wireAuthz is ACME JSON representation of Authorization objects. -type wireAuthz struct { - Status string - Challenges []wireChallenge - Combinations [][]int - Identifier struct { - Type string - Value string - } -} - -func (z *wireAuthz) authorization(uri string) *Authorization { - a := &Authorization{ - URI: uri, - Status: z.Status, - Identifier: AuthzID{Type: z.Identifier.Type, Value: z.Identifier.Value}, - Combinations: z.Combinations, // shallow copy - Challenges: make([]*Challenge, len(z.Challenges)), - } - for i, v := range z.Challenges { - a.Challenges[i] = v.challenge() - } - return a -} - -func (z *wireAuthz) error(uri string) *AuthorizationError { - err := &AuthorizationError{ - URI: uri, - Identifier: z.Identifier.Value, - } - for _, raw := range z.Challenges { - if raw.Error != nil { - err.Errors = append(err.Errors, raw.Error.error(nil)) - } - } - return err -} - -// wireChallenge is ACME JSON challenge representation. -type wireChallenge struct { - URI string `json:"uri"` - Type string - Token string - Status string - Error *wireError -} - -func (c *wireChallenge) challenge() *Challenge { - v := &Challenge{ - URI: c.URI, - Type: c.Type, - Token: c.Token, - Status: c.Status, - } - if v.Status == "" { - v.Status = StatusPending - } - if c.Error != nil { - v.Error = c.Error.error(nil) - } - return v -} - -// wireError is a subset of fields of the Problem Details object -// as described in https://tools.ietf.org/html/rfc7807#section-3.1. -type wireError struct { - Status int - Type string - Detail string -} - -func (e *wireError) error(h http.Header) *Error { - return &Error{ - StatusCode: e.Status, - ProblemType: e.Type, - Detail: e.Detail, - Header: h, - } -} - -// CertOption is an optional argument type for the TLS ChallengeCert methods for -// customizing a temporary certificate for TLS-based challenges. -type CertOption interface { - privateCertOpt() -} - -// WithKey creates an option holding a private/public key pair. -// The private part signs a certificate, and the public part represents the signee. -func WithKey(key crypto.Signer) CertOption { - return &certOptKey{key} -} - -type certOptKey struct { - key crypto.Signer -} - -func (*certOptKey) privateCertOpt() {} - -// WithTemplate creates an option for specifying a certificate template. -// See x509.CreateCertificate for template usage details. -// -// In TLS ChallengeCert methods, the template is also used as parent, -// resulting in a self-signed certificate. -// The DNSNames field of t is always overwritten for tls-sni challenge certs. -func WithTemplate(t *x509.Certificate) CertOption { - return (*certOptTemplate)(t) -} - -type certOptTemplate x509.Certificate - -func (*certOptTemplate) privateCertOpt() {} diff --git a/vendor/golang.org/x/crypto/pbkdf2/pbkdf2.go b/vendor/golang.org/x/crypto/pbkdf2/pbkdf2.go deleted file mode 100644 index 593f653008..0000000000 --- a/vendor/golang.org/x/crypto/pbkdf2/pbkdf2.go +++ /dev/null @@ -1,77 +0,0 @@ -// Copyright 2012 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -/* -Package pbkdf2 implements the key derivation function PBKDF2 as defined in RFC -2898 / PKCS #5 v2.0. - -A key derivation function is useful when encrypting data based on a password -or any other not-fully-random data. It uses a pseudorandom function to derive -a secure encryption key based on the password. - -While v2.0 of the standard defines only one pseudorandom function to use, -HMAC-SHA1, the drafted v2.1 specification allows use of all five FIPS Approved -Hash Functions SHA-1, SHA-224, SHA-256, SHA-384 and SHA-512 for HMAC. To -choose, you can pass the `New` functions from the different SHA packages to -pbkdf2.Key. -*/ -package pbkdf2 // import "golang.org/x/crypto/pbkdf2" - -import ( - "crypto/hmac" - "hash" -) - -// Key derives a key from the password, salt and iteration count, returning a -// []byte of length keylen that can be used as cryptographic key. The key is -// derived based on the method described as PBKDF2 with the HMAC variant using -// the supplied hash function. -// -// For example, to use a HMAC-SHA-1 based PBKDF2 key derivation function, you -// can get a derived key for e.g. AES-256 (which needs a 32-byte key) by -// doing: -// -// dk := pbkdf2.Key([]byte("some password"), salt, 4096, 32, sha1.New) -// -// Remember to get a good random salt. At least 8 bytes is recommended by the -// RFC. -// -// Using a higher iteration count will increase the cost of an exhaustive -// search but will also make derivation proportionally slower. -func Key(password, salt []byte, iter, keyLen int, h func() hash.Hash) []byte { - prf := hmac.New(h, password) - hashLen := prf.Size() - numBlocks := (keyLen + hashLen - 1) / hashLen - - var buf [4]byte - dk := make([]byte, 0, numBlocks*hashLen) - U := make([]byte, hashLen) - for block := 1; block <= numBlocks; block++ { - // N.B.: || means concatenation, ^ means XOR - // for each block T_i = U_1 ^ U_2 ^ ... ^ U_iter - // U_1 = PRF(password, salt || uint(i)) - prf.Reset() - prf.Write(salt) - buf[0] = byte(block >> 24) - buf[1] = byte(block >> 16) - buf[2] = byte(block >> 8) - buf[3] = byte(block) - prf.Write(buf[:4]) - dk = prf.Sum(dk) - T := dk[len(dk)-hashLen:] - copy(U, T) - - // U_n = PRF(password, U_(n-1)) - for n := 2; n <= iter; n++ { - prf.Reset() - prf.Write(U) - U = U[:0] - U = prf.Sum(U) - for x := range U { - T[x] ^= U[x] - } - } - } - return dk[:keyLen] -} diff --git a/vendor/gopkg.in/yaml.v2/LICENSE b/vendor/gopkg.in/yaml.v2/LICENSE deleted file mode 100644 index 8dada3edaf..0000000000 --- a/vendor/gopkg.in/yaml.v2/LICENSE +++ /dev/null @@ -1,201 +0,0 @@ - Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ - - TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - - 1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). - - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - - 2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - - 3. Grant of Patent License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - - 4. Redistribution. You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - - 5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - - 6. Trademarks. This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - - 7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - - 8. Limitation of Liability. In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - - 9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. - - END OF TERMS AND CONDITIONS - - APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "{}" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. - - Copyright {yyyy} {name of copyright owner} - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. diff --git a/vendor/gopkg.in/yaml.v2/LICENSE.libyaml b/vendor/gopkg.in/yaml.v2/LICENSE.libyaml deleted file mode 100644 index 8da58fbf6f..0000000000 --- a/vendor/gopkg.in/yaml.v2/LICENSE.libyaml +++ /dev/null @@ -1,31 +0,0 @@ -The following files were ported to Go from C files of libyaml, and thus -are still covered by their original copyright and license: - - apic.go - emitterc.go - parserc.go - readerc.go - scannerc.go - writerc.go - yamlh.go - yamlprivateh.go - -Copyright (c) 2006 Kirill Simonov - -Permission is hereby granted, free of charge, to any person obtaining a copy of -this software and associated documentation files (the "Software"), to deal in -the Software without restriction, including without limitation the rights to -use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies -of the Software, and to permit persons to whom the Software is furnished to do -so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. diff --git a/vendor/gopkg.in/yaml.v2/README.md b/vendor/gopkg.in/yaml.v2/README.md deleted file mode 100644 index b50c6e8775..0000000000 --- a/vendor/gopkg.in/yaml.v2/README.md +++ /dev/null @@ -1,133 +0,0 @@ -# YAML support for the Go language - -Introduction ------------- - -The yaml package enables Go programs to comfortably encode and decode YAML -values. It was developed within [Canonical](https://www.canonical.com) as -part of the [juju](https://juju.ubuntu.com) project, and is based on a -pure Go port of the well-known [libyaml](http://pyyaml.org/wiki/LibYAML) -C library to parse and generate YAML data quickly and reliably. - -Compatibility -------------- - -The yaml package supports most of YAML 1.1 and 1.2, including support for -anchors, tags, map merging, etc. Multi-document unmarshalling is not yet -implemented, and base-60 floats from YAML 1.1 are purposefully not -supported since they're a poor design and are gone in YAML 1.2. - -Installation and usage ----------------------- - -The import path for the package is *gopkg.in/yaml.v2*. - -To install it, run: - - go get gopkg.in/yaml.v2 - -API documentation ------------------ - -If opened in a browser, the import path itself leads to the API documentation: - - * [https://gopkg.in/yaml.v2](https://gopkg.in/yaml.v2) - -API stability -------------- - -The package API for yaml v2 will remain stable as described in [gopkg.in](https://gopkg.in). - - -License -------- - -The yaml package is licensed under the Apache License 2.0. Please see the LICENSE file for details. - - -Example -------- - -```Go -package main - -import ( - "fmt" - "log" - - "gopkg.in/yaml.v2" -) - -var data = ` -a: Easy! -b: - c: 2 - d: [3, 4] -` - -// Note: struct fields must be public in order for unmarshal to -// correctly populate the data. -type T struct { - A string - B struct { - RenamedC int `yaml:"c"` - D []int `yaml:",flow"` - } -} - -func main() { - t := T{} - - err := yaml.Unmarshal([]byte(data), &t) - if err != nil { - log.Fatalf("error: %v", err) - } - fmt.Printf("--- t:\n%v\n\n", t) - - d, err := yaml.Marshal(&t) - if err != nil { - log.Fatalf("error: %v", err) - } - fmt.Printf("--- t dump:\n%s\n\n", string(d)) - - m := make(map[interface{}]interface{}) - - err = yaml.Unmarshal([]byte(data), &m) - if err != nil { - log.Fatalf("error: %v", err) - } - fmt.Printf("--- m:\n%v\n\n", m) - - d, err = yaml.Marshal(&m) - if err != nil { - log.Fatalf("error: %v", err) - } - fmt.Printf("--- m dump:\n%s\n\n", string(d)) -} -``` - -This example will generate the following output: - -``` ---- t: -{Easy! {2 [3 4]}} - ---- t dump: -a: Easy! -b: - c: 2 - d: [3, 4] - - ---- m: -map[a:Easy! b:map[c:2 d:[3 4]]] - ---- m dump: -a: Easy! -b: - c: 2 - d: - - 3 - - 4 -``` - diff --git a/vendor/gopkg.in/yaml.v2/apic.go b/vendor/gopkg.in/yaml.v2/apic.go deleted file mode 100644 index 1f7e87e672..0000000000 --- a/vendor/gopkg.in/yaml.v2/apic.go +++ /dev/null @@ -1,739 +0,0 @@ -package yaml - -import ( - "io" -) - -func yaml_insert_token(parser *yaml_parser_t, pos int, token *yaml_token_t) { - //fmt.Println("yaml_insert_token", "pos:", pos, "typ:", token.typ, "head:", parser.tokens_head, "len:", len(parser.tokens)) - - // Check if we can move the queue at the beginning of the buffer. - if parser.tokens_head > 0 && len(parser.tokens) == cap(parser.tokens) { - if parser.tokens_head != len(parser.tokens) { - copy(parser.tokens, parser.tokens[parser.tokens_head:]) - } - parser.tokens = parser.tokens[:len(parser.tokens)-parser.tokens_head] - parser.tokens_head = 0 - } - parser.tokens = append(parser.tokens, *token) - if pos < 0 { - return - } - copy(parser.tokens[parser.tokens_head+pos+1:], parser.tokens[parser.tokens_head+pos:]) - parser.tokens[parser.tokens_head+pos] = *token -} - -// Create a new parser object. -func yaml_parser_initialize(parser *yaml_parser_t) bool { - *parser = yaml_parser_t{ - raw_buffer: make([]byte, 0, input_raw_buffer_size), - buffer: make([]byte, 0, input_buffer_size), - } - return true -} - -// Destroy a parser object. -func yaml_parser_delete(parser *yaml_parser_t) { - *parser = yaml_parser_t{} -} - -// String read handler. -func yaml_string_read_handler(parser *yaml_parser_t, buffer []byte) (n int, err error) { - if parser.input_pos == len(parser.input) { - return 0, io.EOF - } - n = copy(buffer, parser.input[parser.input_pos:]) - parser.input_pos += n - return n, nil -} - -// Reader read handler. -func yaml_reader_read_handler(parser *yaml_parser_t, buffer []byte) (n int, err error) { - return parser.input_reader.Read(buffer) -} - -// Set a string input. -func yaml_parser_set_input_string(parser *yaml_parser_t, input []byte) { - if parser.read_handler != nil { - panic("must set the input source only once") - } - parser.read_handler = yaml_string_read_handler - parser.input = input - parser.input_pos = 0 -} - -// Set a file input. -func yaml_parser_set_input_reader(parser *yaml_parser_t, r io.Reader) { - if parser.read_handler != nil { - panic("must set the input source only once") - } - parser.read_handler = yaml_reader_read_handler - parser.input_reader = r -} - -// Set the source encoding. -func yaml_parser_set_encoding(parser *yaml_parser_t, encoding yaml_encoding_t) { - if parser.encoding != yaml_ANY_ENCODING { - panic("must set the encoding only once") - } - parser.encoding = encoding -} - -// Create a new emitter object. -func yaml_emitter_initialize(emitter *yaml_emitter_t) { - *emitter = yaml_emitter_t{ - buffer: make([]byte, output_buffer_size), - raw_buffer: make([]byte, 0, output_raw_buffer_size), - states: make([]yaml_emitter_state_t, 0, initial_stack_size), - events: make([]yaml_event_t, 0, initial_queue_size), - } -} - -// Destroy an emitter object. -func yaml_emitter_delete(emitter *yaml_emitter_t) { - *emitter = yaml_emitter_t{} -} - -// String write handler. -func yaml_string_write_handler(emitter *yaml_emitter_t, buffer []byte) error { - *emitter.output_buffer = append(*emitter.output_buffer, buffer...) - return nil -} - -// yaml_writer_write_handler uses emitter.output_writer to write the -// emitted text. -func yaml_writer_write_handler(emitter *yaml_emitter_t, buffer []byte) error { - _, err := emitter.output_writer.Write(buffer) - return err -} - -// Set a string output. -func yaml_emitter_set_output_string(emitter *yaml_emitter_t, output_buffer *[]byte) { - if emitter.write_handler != nil { - panic("must set the output target only once") - } - emitter.write_handler = yaml_string_write_handler - emitter.output_buffer = output_buffer -} - -// Set a file output. -func yaml_emitter_set_output_writer(emitter *yaml_emitter_t, w io.Writer) { - if emitter.write_handler != nil { - panic("must set the output target only once") - } - emitter.write_handler = yaml_writer_write_handler - emitter.output_writer = w -} - -// Set the output encoding. -func yaml_emitter_set_encoding(emitter *yaml_emitter_t, encoding yaml_encoding_t) { - if emitter.encoding != yaml_ANY_ENCODING { - panic("must set the output encoding only once") - } - emitter.encoding = encoding -} - -// Set the canonical output style. -func yaml_emitter_set_canonical(emitter *yaml_emitter_t, canonical bool) { - emitter.canonical = canonical -} - -//// Set the indentation increment. -func yaml_emitter_set_indent(emitter *yaml_emitter_t, indent int) { - if indent < 2 || indent > 9 { - indent = 2 - } - emitter.best_indent = indent -} - -// Set the preferred line width. -func yaml_emitter_set_width(emitter *yaml_emitter_t, width int) { - if width < 0 { - width = -1 - } - emitter.best_width = width -} - -// Set if unescaped non-ASCII characters are allowed. -func yaml_emitter_set_unicode(emitter *yaml_emitter_t, unicode bool) { - emitter.unicode = unicode -} - -// Set the preferred line break character. -func yaml_emitter_set_break(emitter *yaml_emitter_t, line_break yaml_break_t) { - emitter.line_break = line_break -} - -///* -// * Destroy a token object. -// */ -// -//YAML_DECLARE(void) -//yaml_token_delete(yaml_token_t *token) -//{ -// assert(token); // Non-NULL token object expected. -// -// switch (token.type) -// { -// case YAML_TAG_DIRECTIVE_TOKEN: -// yaml_free(token.data.tag_directive.handle); -// yaml_free(token.data.tag_directive.prefix); -// break; -// -// case YAML_ALIAS_TOKEN: -// yaml_free(token.data.alias.value); -// break; -// -// case YAML_ANCHOR_TOKEN: -// yaml_free(token.data.anchor.value); -// break; -// -// case YAML_TAG_TOKEN: -// yaml_free(token.data.tag.handle); -// yaml_free(token.data.tag.suffix); -// break; -// -// case YAML_SCALAR_TOKEN: -// yaml_free(token.data.scalar.value); -// break; -// -// default: -// break; -// } -// -// memset(token, 0, sizeof(yaml_token_t)); -//} -// -///* -// * Check if a string is a valid UTF-8 sequence. -// * -// * Check 'reader.c' for more details on UTF-8 encoding. -// */ -// -//static int -//yaml_check_utf8(yaml_char_t *start, size_t length) -//{ -// yaml_char_t *end = start+length; -// yaml_char_t *pointer = start; -// -// while (pointer < end) { -// unsigned char octet; -// unsigned int width; -// unsigned int value; -// size_t k; -// -// octet = pointer[0]; -// width = (octet & 0x80) == 0x00 ? 1 : -// (octet & 0xE0) == 0xC0 ? 2 : -// (octet & 0xF0) == 0xE0 ? 3 : -// (octet & 0xF8) == 0xF0 ? 4 : 0; -// value = (octet & 0x80) == 0x00 ? octet & 0x7F : -// (octet & 0xE0) == 0xC0 ? octet & 0x1F : -// (octet & 0xF0) == 0xE0 ? octet & 0x0F : -// (octet & 0xF8) == 0xF0 ? octet & 0x07 : 0; -// if (!width) return 0; -// if (pointer+width > end) return 0; -// for (k = 1; k < width; k ++) { -// octet = pointer[k]; -// if ((octet & 0xC0) != 0x80) return 0; -// value = (value << 6) + (octet & 0x3F); -// } -// if (!((width == 1) || -// (width == 2 && value >= 0x80) || -// (width == 3 && value >= 0x800) || -// (width == 4 && value >= 0x10000))) return 0; -// -// pointer += width; -// } -// -// return 1; -//} -// - -// Create STREAM-START. -func yaml_stream_start_event_initialize(event *yaml_event_t, encoding yaml_encoding_t) { - *event = yaml_event_t{ - typ: yaml_STREAM_START_EVENT, - encoding: encoding, - } -} - -// Create STREAM-END. -func yaml_stream_end_event_initialize(event *yaml_event_t) { - *event = yaml_event_t{ - typ: yaml_STREAM_END_EVENT, - } -} - -// Create DOCUMENT-START. -func yaml_document_start_event_initialize( - event *yaml_event_t, - version_directive *yaml_version_directive_t, - tag_directives []yaml_tag_directive_t, - implicit bool, -) { - *event = yaml_event_t{ - typ: yaml_DOCUMENT_START_EVENT, - version_directive: version_directive, - tag_directives: tag_directives, - implicit: implicit, - } -} - -// Create DOCUMENT-END. -func yaml_document_end_event_initialize(event *yaml_event_t, implicit bool) { - *event = yaml_event_t{ - typ: yaml_DOCUMENT_END_EVENT, - implicit: implicit, - } -} - -///* -// * Create ALIAS. -// */ -// -//YAML_DECLARE(int) -//yaml_alias_event_initialize(event *yaml_event_t, anchor *yaml_char_t) -//{ -// mark yaml_mark_t = { 0, 0, 0 } -// anchor_copy *yaml_char_t = NULL -// -// assert(event) // Non-NULL event object is expected. -// assert(anchor) // Non-NULL anchor is expected. -// -// if (!yaml_check_utf8(anchor, strlen((char *)anchor))) return 0 -// -// anchor_copy = yaml_strdup(anchor) -// if (!anchor_copy) -// return 0 -// -// ALIAS_EVENT_INIT(*event, anchor_copy, mark, mark) -// -// return 1 -//} - -// Create SCALAR. -func yaml_scalar_event_initialize(event *yaml_event_t, anchor, tag, value []byte, plain_implicit, quoted_implicit bool, style yaml_scalar_style_t) bool { - *event = yaml_event_t{ - typ: yaml_SCALAR_EVENT, - anchor: anchor, - tag: tag, - value: value, - implicit: plain_implicit, - quoted_implicit: quoted_implicit, - style: yaml_style_t(style), - } - return true -} - -// Create SEQUENCE-START. -func yaml_sequence_start_event_initialize(event *yaml_event_t, anchor, tag []byte, implicit bool, style yaml_sequence_style_t) bool { - *event = yaml_event_t{ - typ: yaml_SEQUENCE_START_EVENT, - anchor: anchor, - tag: tag, - implicit: implicit, - style: yaml_style_t(style), - } - return true -} - -// Create SEQUENCE-END. -func yaml_sequence_end_event_initialize(event *yaml_event_t) bool { - *event = yaml_event_t{ - typ: yaml_SEQUENCE_END_EVENT, - } - return true -} - -// Create MAPPING-START. -func yaml_mapping_start_event_initialize(event *yaml_event_t, anchor, tag []byte, implicit bool, style yaml_mapping_style_t) { - *event = yaml_event_t{ - typ: yaml_MAPPING_START_EVENT, - anchor: anchor, - tag: tag, - implicit: implicit, - style: yaml_style_t(style), - } -} - -// Create MAPPING-END. -func yaml_mapping_end_event_initialize(event *yaml_event_t) { - *event = yaml_event_t{ - typ: yaml_MAPPING_END_EVENT, - } -} - -// Destroy an event object. -func yaml_event_delete(event *yaml_event_t) { - *event = yaml_event_t{} -} - -///* -// * Create a document object. -// */ -// -//YAML_DECLARE(int) -//yaml_document_initialize(document *yaml_document_t, -// version_directive *yaml_version_directive_t, -// tag_directives_start *yaml_tag_directive_t, -// tag_directives_end *yaml_tag_directive_t, -// start_implicit int, end_implicit int) -//{ -// struct { -// error yaml_error_type_t -// } context -// struct { -// start *yaml_node_t -// end *yaml_node_t -// top *yaml_node_t -// } nodes = { NULL, NULL, NULL } -// version_directive_copy *yaml_version_directive_t = NULL -// struct { -// start *yaml_tag_directive_t -// end *yaml_tag_directive_t -// top *yaml_tag_directive_t -// } tag_directives_copy = { NULL, NULL, NULL } -// value yaml_tag_directive_t = { NULL, NULL } -// mark yaml_mark_t = { 0, 0, 0 } -// -// assert(document) // Non-NULL document object is expected. -// assert((tag_directives_start && tag_directives_end) || -// (tag_directives_start == tag_directives_end)) -// // Valid tag directives are expected. -// -// if (!STACK_INIT(&context, nodes, INITIAL_STACK_SIZE)) goto error -// -// if (version_directive) { -// version_directive_copy = yaml_malloc(sizeof(yaml_version_directive_t)) -// if (!version_directive_copy) goto error -// version_directive_copy.major = version_directive.major -// version_directive_copy.minor = version_directive.minor -// } -// -// if (tag_directives_start != tag_directives_end) { -// tag_directive *yaml_tag_directive_t -// if (!STACK_INIT(&context, tag_directives_copy, INITIAL_STACK_SIZE)) -// goto error -// for (tag_directive = tag_directives_start -// tag_directive != tag_directives_end; tag_directive ++) { -// assert(tag_directive.handle) -// assert(tag_directive.prefix) -// if (!yaml_check_utf8(tag_directive.handle, -// strlen((char *)tag_directive.handle))) -// goto error -// if (!yaml_check_utf8(tag_directive.prefix, -// strlen((char *)tag_directive.prefix))) -// goto error -// value.handle = yaml_strdup(tag_directive.handle) -// value.prefix = yaml_strdup(tag_directive.prefix) -// if (!value.handle || !value.prefix) goto error -// if (!PUSH(&context, tag_directives_copy, value)) -// goto error -// value.handle = NULL -// value.prefix = NULL -// } -// } -// -// DOCUMENT_INIT(*document, nodes.start, nodes.end, version_directive_copy, -// tag_directives_copy.start, tag_directives_copy.top, -// start_implicit, end_implicit, mark, mark) -// -// return 1 -// -//error: -// STACK_DEL(&context, nodes) -// yaml_free(version_directive_copy) -// while (!STACK_EMPTY(&context, tag_directives_copy)) { -// value yaml_tag_directive_t = POP(&context, tag_directives_copy) -// yaml_free(value.handle) -// yaml_free(value.prefix) -// } -// STACK_DEL(&context, tag_directives_copy) -// yaml_free(value.handle) -// yaml_free(value.prefix) -// -// return 0 -//} -// -///* -// * Destroy a document object. -// */ -// -//YAML_DECLARE(void) -//yaml_document_delete(document *yaml_document_t) -//{ -// struct { -// error yaml_error_type_t -// } context -// tag_directive *yaml_tag_directive_t -// -// context.error = YAML_NO_ERROR // Eliminate a compiler warning. -// -// assert(document) // Non-NULL document object is expected. -// -// while (!STACK_EMPTY(&context, document.nodes)) { -// node yaml_node_t = POP(&context, document.nodes) -// yaml_free(node.tag) -// switch (node.type) { -// case YAML_SCALAR_NODE: -// yaml_free(node.data.scalar.value) -// break -// case YAML_SEQUENCE_NODE: -// STACK_DEL(&context, node.data.sequence.items) -// break -// case YAML_MAPPING_NODE: -// STACK_DEL(&context, node.data.mapping.pairs) -// break -// default: -// assert(0) // Should not happen. -// } -// } -// STACK_DEL(&context, document.nodes) -// -// yaml_free(document.version_directive) -// for (tag_directive = document.tag_directives.start -// tag_directive != document.tag_directives.end -// tag_directive++) { -// yaml_free(tag_directive.handle) -// yaml_free(tag_directive.prefix) -// } -// yaml_free(document.tag_directives.start) -// -// memset(document, 0, sizeof(yaml_document_t)) -//} -// -///** -// * Get a document node. -// */ -// -//YAML_DECLARE(yaml_node_t *) -//yaml_document_get_node(document *yaml_document_t, index int) -//{ -// assert(document) // Non-NULL document object is expected. -// -// if (index > 0 && document.nodes.start + index <= document.nodes.top) { -// return document.nodes.start + index - 1 -// } -// return NULL -//} -// -///** -// * Get the root object. -// */ -// -//YAML_DECLARE(yaml_node_t *) -//yaml_document_get_root_node(document *yaml_document_t) -//{ -// assert(document) // Non-NULL document object is expected. -// -// if (document.nodes.top != document.nodes.start) { -// return document.nodes.start -// } -// return NULL -//} -// -///* -// * Add a scalar node to a document. -// */ -// -//YAML_DECLARE(int) -//yaml_document_add_scalar(document *yaml_document_t, -// tag *yaml_char_t, value *yaml_char_t, length int, -// style yaml_scalar_style_t) -//{ -// struct { -// error yaml_error_type_t -// } context -// mark yaml_mark_t = { 0, 0, 0 } -// tag_copy *yaml_char_t = NULL -// value_copy *yaml_char_t = NULL -// node yaml_node_t -// -// assert(document) // Non-NULL document object is expected. -// assert(value) // Non-NULL value is expected. -// -// if (!tag) { -// tag = (yaml_char_t *)YAML_DEFAULT_SCALAR_TAG -// } -// -// if (!yaml_check_utf8(tag, strlen((char *)tag))) goto error -// tag_copy = yaml_strdup(tag) -// if (!tag_copy) goto error -// -// if (length < 0) { -// length = strlen((char *)value) -// } -// -// if (!yaml_check_utf8(value, length)) goto error -// value_copy = yaml_malloc(length+1) -// if (!value_copy) goto error -// memcpy(value_copy, value, length) -// value_copy[length] = '\0' -// -// SCALAR_NODE_INIT(node, tag_copy, value_copy, length, style, mark, mark) -// if (!PUSH(&context, document.nodes, node)) goto error -// -// return document.nodes.top - document.nodes.start -// -//error: -// yaml_free(tag_copy) -// yaml_free(value_copy) -// -// return 0 -//} -// -///* -// * Add a sequence node to a document. -// */ -// -//YAML_DECLARE(int) -//yaml_document_add_sequence(document *yaml_document_t, -// tag *yaml_char_t, style yaml_sequence_style_t) -//{ -// struct { -// error yaml_error_type_t -// } context -// mark yaml_mark_t = { 0, 0, 0 } -// tag_copy *yaml_char_t = NULL -// struct { -// start *yaml_node_item_t -// end *yaml_node_item_t -// top *yaml_node_item_t -// } items = { NULL, NULL, NULL } -// node yaml_node_t -// -// assert(document) // Non-NULL document object is expected. -// -// if (!tag) { -// tag = (yaml_char_t *)YAML_DEFAULT_SEQUENCE_TAG -// } -// -// if (!yaml_check_utf8(tag, strlen((char *)tag))) goto error -// tag_copy = yaml_strdup(tag) -// if (!tag_copy) goto error -// -// if (!STACK_INIT(&context, items, INITIAL_STACK_SIZE)) goto error -// -// SEQUENCE_NODE_INIT(node, tag_copy, items.start, items.end, -// style, mark, mark) -// if (!PUSH(&context, document.nodes, node)) goto error -// -// return document.nodes.top - document.nodes.start -// -//error: -// STACK_DEL(&context, items) -// yaml_free(tag_copy) -// -// return 0 -//} -// -///* -// * Add a mapping node to a document. -// */ -// -//YAML_DECLARE(int) -//yaml_document_add_mapping(document *yaml_document_t, -// tag *yaml_char_t, style yaml_mapping_style_t) -//{ -// struct { -// error yaml_error_type_t -// } context -// mark yaml_mark_t = { 0, 0, 0 } -// tag_copy *yaml_char_t = NULL -// struct { -// start *yaml_node_pair_t -// end *yaml_node_pair_t -// top *yaml_node_pair_t -// } pairs = { NULL, NULL, NULL } -// node yaml_node_t -// -// assert(document) // Non-NULL document object is expected. -// -// if (!tag) { -// tag = (yaml_char_t *)YAML_DEFAULT_MAPPING_TAG -// } -// -// if (!yaml_check_utf8(tag, strlen((char *)tag))) goto error -// tag_copy = yaml_strdup(tag) -// if (!tag_copy) goto error -// -// if (!STACK_INIT(&context, pairs, INITIAL_STACK_SIZE)) goto error -// -// MAPPING_NODE_INIT(node, tag_copy, pairs.start, pairs.end, -// style, mark, mark) -// if (!PUSH(&context, document.nodes, node)) goto error -// -// return document.nodes.top - document.nodes.start -// -//error: -// STACK_DEL(&context, pairs) -// yaml_free(tag_copy) -// -// return 0 -//} -// -///* -// * Append an item to a sequence node. -// */ -// -//YAML_DECLARE(int) -//yaml_document_append_sequence_item(document *yaml_document_t, -// sequence int, item int) -//{ -// struct { -// error yaml_error_type_t -// } context -// -// assert(document) // Non-NULL document is required. -// assert(sequence > 0 -// && document.nodes.start + sequence <= document.nodes.top) -// // Valid sequence id is required. -// assert(document.nodes.start[sequence-1].type == YAML_SEQUENCE_NODE) -// // A sequence node is required. -// assert(item > 0 && document.nodes.start + item <= document.nodes.top) -// // Valid item id is required. -// -// if (!PUSH(&context, -// document.nodes.start[sequence-1].data.sequence.items, item)) -// return 0 -// -// return 1 -//} -// -///* -// * Append a pair of a key and a value to a mapping node. -// */ -// -//YAML_DECLARE(int) -//yaml_document_append_mapping_pair(document *yaml_document_t, -// mapping int, key int, value int) -//{ -// struct { -// error yaml_error_type_t -// } context -// -// pair yaml_node_pair_t -// -// assert(document) // Non-NULL document is required. -// assert(mapping > 0 -// && document.nodes.start + mapping <= document.nodes.top) -// // Valid mapping id is required. -// assert(document.nodes.start[mapping-1].type == YAML_MAPPING_NODE) -// // A mapping node is required. -// assert(key > 0 && document.nodes.start + key <= document.nodes.top) -// // Valid key id is required. -// assert(value > 0 && document.nodes.start + value <= document.nodes.top) -// // Valid value id is required. -// -// pair.key = key -// pair.value = value -// -// if (!PUSH(&context, -// document.nodes.start[mapping-1].data.mapping.pairs, pair)) -// return 0 -// -// return 1 -//} -// -// diff --git a/vendor/gopkg.in/yaml.v2/decode.go b/vendor/gopkg.in/yaml.v2/decode.go deleted file mode 100644 index e4e56e28e0..0000000000 --- a/vendor/gopkg.in/yaml.v2/decode.go +++ /dev/null @@ -1,775 +0,0 @@ -package yaml - -import ( - "encoding" - "encoding/base64" - "fmt" - "io" - "math" - "reflect" - "strconv" - "time" -) - -const ( - documentNode = 1 << iota - mappingNode - sequenceNode - scalarNode - aliasNode -) - -type node struct { - kind int - line, column int - tag string - // For an alias node, alias holds the resolved alias. - alias *node - value string - implicit bool - children []*node - anchors map[string]*node -} - -// ---------------------------------------------------------------------------- -// Parser, produces a node tree out of a libyaml event stream. - -type parser struct { - parser yaml_parser_t - event yaml_event_t - doc *node - doneInit bool -} - -func newParser(b []byte) *parser { - p := parser{} - if !yaml_parser_initialize(&p.parser) { - panic("failed to initialize YAML emitter") - } - if len(b) == 0 { - b = []byte{'\n'} - } - yaml_parser_set_input_string(&p.parser, b) - return &p -} - -func newParserFromReader(r io.Reader) *parser { - p := parser{} - if !yaml_parser_initialize(&p.parser) { - panic("failed to initialize YAML emitter") - } - yaml_parser_set_input_reader(&p.parser, r) - return &p -} - -func (p *parser) init() { - if p.doneInit { - return - } - p.expect(yaml_STREAM_START_EVENT) - p.doneInit = true -} - -func (p *parser) destroy() { - if p.event.typ != yaml_NO_EVENT { - yaml_event_delete(&p.event) - } - yaml_parser_delete(&p.parser) -} - -// expect consumes an event from the event stream and -// checks that it's of the expected type. -func (p *parser) expect(e yaml_event_type_t) { - if p.event.typ == yaml_NO_EVENT { - if !yaml_parser_parse(&p.parser, &p.event) { - p.fail() - } - } - if p.event.typ == yaml_STREAM_END_EVENT { - failf("attempted to go past the end of stream; corrupted value?") - } - if p.event.typ != e { - p.parser.problem = fmt.Sprintf("expected %s event but got %s", e, p.event.typ) - p.fail() - } - yaml_event_delete(&p.event) - p.event.typ = yaml_NO_EVENT -} - -// peek peeks at the next event in the event stream, -// puts the results into p.event and returns the event type. -func (p *parser) peek() yaml_event_type_t { - if p.event.typ != yaml_NO_EVENT { - return p.event.typ - } - if !yaml_parser_parse(&p.parser, &p.event) { - p.fail() - } - return p.event.typ -} - -func (p *parser) fail() { - var where string - var line int - if p.parser.problem_mark.line != 0 { - line = p.parser.problem_mark.line - // Scanner errors don't iterate line before returning error - if p.parser.error == yaml_SCANNER_ERROR { - line++ - } - } else if p.parser.context_mark.line != 0 { - line = p.parser.context_mark.line - } - if line != 0 { - where = "line " + strconv.Itoa(line) + ": " - } - var msg string - if len(p.parser.problem) > 0 { - msg = p.parser.problem - } else { - msg = "unknown problem parsing YAML content" - } - failf("%s%s", where, msg) -} - -func (p *parser) anchor(n *node, anchor []byte) { - if anchor != nil { - p.doc.anchors[string(anchor)] = n - } -} - -func (p *parser) parse() *node { - p.init() - switch p.peek() { - case yaml_SCALAR_EVENT: - return p.scalar() - case yaml_ALIAS_EVENT: - return p.alias() - case yaml_MAPPING_START_EVENT: - return p.mapping() - case yaml_SEQUENCE_START_EVENT: - return p.sequence() - case yaml_DOCUMENT_START_EVENT: - return p.document() - case yaml_STREAM_END_EVENT: - // Happens when attempting to decode an empty buffer. - return nil - default: - panic("attempted to parse unknown event: " + p.event.typ.String()) - } -} - -func (p *parser) node(kind int) *node { - return &node{ - kind: kind, - line: p.event.start_mark.line, - column: p.event.start_mark.column, - } -} - -func (p *parser) document() *node { - n := p.node(documentNode) - n.anchors = make(map[string]*node) - p.doc = n - p.expect(yaml_DOCUMENT_START_EVENT) - n.children = append(n.children, p.parse()) - p.expect(yaml_DOCUMENT_END_EVENT) - return n -} - -func (p *parser) alias() *node { - n := p.node(aliasNode) - n.value = string(p.event.anchor) - n.alias = p.doc.anchors[n.value] - if n.alias == nil { - failf("unknown anchor '%s' referenced", n.value) - } - p.expect(yaml_ALIAS_EVENT) - return n -} - -func (p *parser) scalar() *node { - n := p.node(scalarNode) - n.value = string(p.event.value) - n.tag = string(p.event.tag) - n.implicit = p.event.implicit - p.anchor(n, p.event.anchor) - p.expect(yaml_SCALAR_EVENT) - return n -} - -func (p *parser) sequence() *node { - n := p.node(sequenceNode) - p.anchor(n, p.event.anchor) - p.expect(yaml_SEQUENCE_START_EVENT) - for p.peek() != yaml_SEQUENCE_END_EVENT { - n.children = append(n.children, p.parse()) - } - p.expect(yaml_SEQUENCE_END_EVENT) - return n -} - -func (p *parser) mapping() *node { - n := p.node(mappingNode) - p.anchor(n, p.event.anchor) - p.expect(yaml_MAPPING_START_EVENT) - for p.peek() != yaml_MAPPING_END_EVENT { - n.children = append(n.children, p.parse(), p.parse()) - } - p.expect(yaml_MAPPING_END_EVENT) - return n -} - -// ---------------------------------------------------------------------------- -// Decoder, unmarshals a node into a provided value. - -type decoder struct { - doc *node - aliases map[*node]bool - mapType reflect.Type - terrors []string - strict bool -} - -var ( - mapItemType = reflect.TypeOf(MapItem{}) - durationType = reflect.TypeOf(time.Duration(0)) - defaultMapType = reflect.TypeOf(map[interface{}]interface{}{}) - ifaceType = defaultMapType.Elem() - timeType = reflect.TypeOf(time.Time{}) - ptrTimeType = reflect.TypeOf(&time.Time{}) -) - -func newDecoder(strict bool) *decoder { - d := &decoder{mapType: defaultMapType, strict: strict} - d.aliases = make(map[*node]bool) - return d -} - -func (d *decoder) terror(n *node, tag string, out reflect.Value) { - if n.tag != "" { - tag = n.tag - } - value := n.value - if tag != yaml_SEQ_TAG && tag != yaml_MAP_TAG { - if len(value) > 10 { - value = " `" + value[:7] + "...`" - } else { - value = " `" + value + "`" - } - } - d.terrors = append(d.terrors, fmt.Sprintf("line %d: cannot unmarshal %s%s into %s", n.line+1, shortTag(tag), value, out.Type())) -} - -func (d *decoder) callUnmarshaler(n *node, u Unmarshaler) (good bool) { - terrlen := len(d.terrors) - err := u.UnmarshalYAML(func(v interface{}) (err error) { - defer handleErr(&err) - d.unmarshal(n, reflect.ValueOf(v)) - if len(d.terrors) > terrlen { - issues := d.terrors[terrlen:] - d.terrors = d.terrors[:terrlen] - return &TypeError{issues} - } - return nil - }) - if e, ok := err.(*TypeError); ok { - d.terrors = append(d.terrors, e.Errors...) - return false - } - if err != nil { - fail(err) - } - return true -} - -// d.prepare initializes and dereferences pointers and calls UnmarshalYAML -// if a value is found to implement it. -// It returns the initialized and dereferenced out value, whether -// unmarshalling was already done by UnmarshalYAML, and if so whether -// its types unmarshalled appropriately. -// -// If n holds a null value, prepare returns before doing anything. -func (d *decoder) prepare(n *node, out reflect.Value) (newout reflect.Value, unmarshaled, good bool) { - if n.tag == yaml_NULL_TAG || n.kind == scalarNode && n.tag == "" && (n.value == "null" || n.value == "~" || n.value == "" && n.implicit) { - return out, false, false - } - again := true - for again { - again = false - if out.Kind() == reflect.Ptr { - if out.IsNil() { - out.Set(reflect.New(out.Type().Elem())) - } - out = out.Elem() - again = true - } - if out.CanAddr() { - if u, ok := out.Addr().Interface().(Unmarshaler); ok { - good = d.callUnmarshaler(n, u) - return out, true, good - } - } - } - return out, false, false -} - -func (d *decoder) unmarshal(n *node, out reflect.Value) (good bool) { - switch n.kind { - case documentNode: - return d.document(n, out) - case aliasNode: - return d.alias(n, out) - } - out, unmarshaled, good := d.prepare(n, out) - if unmarshaled { - return good - } - switch n.kind { - case scalarNode: - good = d.scalar(n, out) - case mappingNode: - good = d.mapping(n, out) - case sequenceNode: - good = d.sequence(n, out) - default: - panic("internal error: unknown node kind: " + strconv.Itoa(n.kind)) - } - return good -} - -func (d *decoder) document(n *node, out reflect.Value) (good bool) { - if len(n.children) == 1 { - d.doc = n - d.unmarshal(n.children[0], out) - return true - } - return false -} - -func (d *decoder) alias(n *node, out reflect.Value) (good bool) { - if d.aliases[n] { - // TODO this could actually be allowed in some circumstances. - failf("anchor '%s' value contains itself", n.value) - } - d.aliases[n] = true - good = d.unmarshal(n.alias, out) - delete(d.aliases, n) - return good -} - -var zeroValue reflect.Value - -func resetMap(out reflect.Value) { - for _, k := range out.MapKeys() { - out.SetMapIndex(k, zeroValue) - } -} - -func (d *decoder) scalar(n *node, out reflect.Value) bool { - var tag string - var resolved interface{} - if n.tag == "" && !n.implicit { - tag = yaml_STR_TAG - resolved = n.value - } else { - tag, resolved = resolve(n.tag, n.value) - if tag == yaml_BINARY_TAG { - data, err := base64.StdEncoding.DecodeString(resolved.(string)) - if err != nil { - failf("!!binary value contains invalid base64 data") - } - resolved = string(data) - } - } - if resolved == nil { - if out.Kind() == reflect.Map && !out.CanAddr() { - resetMap(out) - } else { - out.Set(reflect.Zero(out.Type())) - } - return true - } - if resolvedv := reflect.ValueOf(resolved); out.Type() == resolvedv.Type() { - // We've resolved to exactly the type we want, so use that. - out.Set(resolvedv) - return true - } - // Perhaps we can use the value as a TextUnmarshaler to - // set its value. - if out.CanAddr() { - u, ok := out.Addr().Interface().(encoding.TextUnmarshaler) - if ok { - var text []byte - if tag == yaml_BINARY_TAG { - text = []byte(resolved.(string)) - } else { - // We let any value be unmarshaled into TextUnmarshaler. - // That might be more lax than we'd like, but the - // TextUnmarshaler itself should bowl out any dubious values. - text = []byte(n.value) - } - err := u.UnmarshalText(text) - if err != nil { - fail(err) - } - return true - } - } - switch out.Kind() { - case reflect.String: - if tag == yaml_BINARY_TAG { - out.SetString(resolved.(string)) - return true - } - if resolved != nil { - out.SetString(n.value) - return true - } - case reflect.Interface: - if resolved == nil { - out.Set(reflect.Zero(out.Type())) - } else if tag == yaml_TIMESTAMP_TAG { - // It looks like a timestamp but for backward compatibility - // reasons we set it as a string, so that code that unmarshals - // timestamp-like values into interface{} will continue to - // see a string and not a time.Time. - // TODO(v3) Drop this. - out.Set(reflect.ValueOf(n.value)) - } else { - out.Set(reflect.ValueOf(resolved)) - } - return true - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - switch resolved := resolved.(type) { - case int: - if !out.OverflowInt(int64(resolved)) { - out.SetInt(int64(resolved)) - return true - } - case int64: - if !out.OverflowInt(resolved) { - out.SetInt(resolved) - return true - } - case uint64: - if resolved <= math.MaxInt64 && !out.OverflowInt(int64(resolved)) { - out.SetInt(int64(resolved)) - return true - } - case float64: - if resolved <= math.MaxInt64 && !out.OverflowInt(int64(resolved)) { - out.SetInt(int64(resolved)) - return true - } - case string: - if out.Type() == durationType { - d, err := time.ParseDuration(resolved) - if err == nil { - out.SetInt(int64(d)) - return true - } - } - } - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: - switch resolved := resolved.(type) { - case int: - if resolved >= 0 && !out.OverflowUint(uint64(resolved)) { - out.SetUint(uint64(resolved)) - return true - } - case int64: - if resolved >= 0 && !out.OverflowUint(uint64(resolved)) { - out.SetUint(uint64(resolved)) - return true - } - case uint64: - if !out.OverflowUint(uint64(resolved)) { - out.SetUint(uint64(resolved)) - return true - } - case float64: - if resolved <= math.MaxUint64 && !out.OverflowUint(uint64(resolved)) { - out.SetUint(uint64(resolved)) - return true - } - } - case reflect.Bool: - switch resolved := resolved.(type) { - case bool: - out.SetBool(resolved) - return true - } - case reflect.Float32, reflect.Float64: - switch resolved := resolved.(type) { - case int: - out.SetFloat(float64(resolved)) - return true - case int64: - out.SetFloat(float64(resolved)) - return true - case uint64: - out.SetFloat(float64(resolved)) - return true - case float64: - out.SetFloat(resolved) - return true - } - case reflect.Struct: - if resolvedv := reflect.ValueOf(resolved); out.Type() == resolvedv.Type() { - out.Set(resolvedv) - return true - } - case reflect.Ptr: - if out.Type().Elem() == reflect.TypeOf(resolved) { - // TODO DOes this make sense? When is out a Ptr except when decoding a nil value? - elem := reflect.New(out.Type().Elem()) - elem.Elem().Set(reflect.ValueOf(resolved)) - out.Set(elem) - return true - } - } - d.terror(n, tag, out) - return false -} - -func settableValueOf(i interface{}) reflect.Value { - v := reflect.ValueOf(i) - sv := reflect.New(v.Type()).Elem() - sv.Set(v) - return sv -} - -func (d *decoder) sequence(n *node, out reflect.Value) (good bool) { - l := len(n.children) - - var iface reflect.Value - switch out.Kind() { - case reflect.Slice: - out.Set(reflect.MakeSlice(out.Type(), l, l)) - case reflect.Array: - if l != out.Len() { - failf("invalid array: want %d elements but got %d", out.Len(), l) - } - case reflect.Interface: - // No type hints. Will have to use a generic sequence. - iface = out - out = settableValueOf(make([]interface{}, l)) - default: - d.terror(n, yaml_SEQ_TAG, out) - return false - } - et := out.Type().Elem() - - j := 0 - for i := 0; i < l; i++ { - e := reflect.New(et).Elem() - if ok := d.unmarshal(n.children[i], e); ok { - out.Index(j).Set(e) - j++ - } - } - if out.Kind() != reflect.Array { - out.Set(out.Slice(0, j)) - } - if iface.IsValid() { - iface.Set(out) - } - return true -} - -func (d *decoder) mapping(n *node, out reflect.Value) (good bool) { - switch out.Kind() { - case reflect.Struct: - return d.mappingStruct(n, out) - case reflect.Slice: - return d.mappingSlice(n, out) - case reflect.Map: - // okay - case reflect.Interface: - if d.mapType.Kind() == reflect.Map { - iface := out - out = reflect.MakeMap(d.mapType) - iface.Set(out) - } else { - slicev := reflect.New(d.mapType).Elem() - if !d.mappingSlice(n, slicev) { - return false - } - out.Set(slicev) - return true - } - default: - d.terror(n, yaml_MAP_TAG, out) - return false - } - outt := out.Type() - kt := outt.Key() - et := outt.Elem() - - mapType := d.mapType - if outt.Key() == ifaceType && outt.Elem() == ifaceType { - d.mapType = outt - } - - if out.IsNil() { - out.Set(reflect.MakeMap(outt)) - } - l := len(n.children) - for i := 0; i < l; i += 2 { - if isMerge(n.children[i]) { - d.merge(n.children[i+1], out) - continue - } - k := reflect.New(kt).Elem() - if d.unmarshal(n.children[i], k) { - kkind := k.Kind() - if kkind == reflect.Interface { - kkind = k.Elem().Kind() - } - if kkind == reflect.Map || kkind == reflect.Slice { - failf("invalid map key: %#v", k.Interface()) - } - e := reflect.New(et).Elem() - if d.unmarshal(n.children[i+1], e) { - d.setMapIndex(n.children[i+1], out, k, e) - } - } - } - d.mapType = mapType - return true -} - -func (d *decoder) setMapIndex(n *node, out, k, v reflect.Value) { - if d.strict && out.MapIndex(k) != zeroValue { - d.terrors = append(d.terrors, fmt.Sprintf("line %d: key %#v already set in map", n.line+1, k.Interface())) - return - } - out.SetMapIndex(k, v) -} - -func (d *decoder) mappingSlice(n *node, out reflect.Value) (good bool) { - outt := out.Type() - if outt.Elem() != mapItemType { - d.terror(n, yaml_MAP_TAG, out) - return false - } - - mapType := d.mapType - d.mapType = outt - - var slice []MapItem - var l = len(n.children) - for i := 0; i < l; i += 2 { - if isMerge(n.children[i]) { - d.merge(n.children[i+1], out) - continue - } - item := MapItem{} - k := reflect.ValueOf(&item.Key).Elem() - if d.unmarshal(n.children[i], k) { - v := reflect.ValueOf(&item.Value).Elem() - if d.unmarshal(n.children[i+1], v) { - slice = append(slice, item) - } - } - } - out.Set(reflect.ValueOf(slice)) - d.mapType = mapType - return true -} - -func (d *decoder) mappingStruct(n *node, out reflect.Value) (good bool) { - sinfo, err := getStructInfo(out.Type()) - if err != nil { - panic(err) - } - name := settableValueOf("") - l := len(n.children) - - var inlineMap reflect.Value - var elemType reflect.Type - if sinfo.InlineMap != -1 { - inlineMap = out.Field(sinfo.InlineMap) - inlineMap.Set(reflect.New(inlineMap.Type()).Elem()) - elemType = inlineMap.Type().Elem() - } - - var doneFields []bool - if d.strict { - doneFields = make([]bool, len(sinfo.FieldsList)) - } - for i := 0; i < l; i += 2 { - ni := n.children[i] - if isMerge(ni) { - d.merge(n.children[i+1], out) - continue - } - if !d.unmarshal(ni, name) { - continue - } - if info, ok := sinfo.FieldsMap[name.String()]; ok { - if d.strict { - if doneFields[info.Id] { - d.terrors = append(d.terrors, fmt.Sprintf("line %d: field %s already set in type %s", ni.line+1, name.String(), out.Type())) - continue - } - doneFields[info.Id] = true - } - var field reflect.Value - if info.Inline == nil { - field = out.Field(info.Num) - } else { - field = out.FieldByIndex(info.Inline) - } - d.unmarshal(n.children[i+1], field) - } else if sinfo.InlineMap != -1 { - if inlineMap.IsNil() { - inlineMap.Set(reflect.MakeMap(inlineMap.Type())) - } - value := reflect.New(elemType).Elem() - d.unmarshal(n.children[i+1], value) - d.setMapIndex(n.children[i+1], inlineMap, name, value) - } else if d.strict { - d.terrors = append(d.terrors, fmt.Sprintf("line %d: field %s not found in type %s", ni.line+1, name.String(), out.Type())) - } - } - return true -} - -func failWantMap() { - failf("map merge requires map or sequence of maps as the value") -} - -func (d *decoder) merge(n *node, out reflect.Value) { - switch n.kind { - case mappingNode: - d.unmarshal(n, out) - case aliasNode: - an, ok := d.doc.anchors[n.value] - if ok && an.kind != mappingNode { - failWantMap() - } - d.unmarshal(n, out) - case sequenceNode: - // Step backwards as earlier nodes take precedence. - for i := len(n.children) - 1; i >= 0; i-- { - ni := n.children[i] - if ni.kind == aliasNode { - an, ok := d.doc.anchors[ni.value] - if ok && an.kind != mappingNode { - failWantMap() - } - } else if ni.kind != mappingNode { - failWantMap() - } - d.unmarshal(ni, out) - } - default: - failWantMap() - } -} - -func isMerge(n *node) bool { - return n.kind == scalarNode && n.value == "<<" && (n.implicit == true || n.tag == yaml_MERGE_TAG) -} diff --git a/vendor/gopkg.in/yaml.v2/emitterc.go b/vendor/gopkg.in/yaml.v2/emitterc.go deleted file mode 100644 index a1c2cc5262..0000000000 --- a/vendor/gopkg.in/yaml.v2/emitterc.go +++ /dev/null @@ -1,1685 +0,0 @@ -package yaml - -import ( - "bytes" - "fmt" -) - -// Flush the buffer if needed. -func flush(emitter *yaml_emitter_t) bool { - if emitter.buffer_pos+5 >= len(emitter.buffer) { - return yaml_emitter_flush(emitter) - } - return true -} - -// Put a character to the output buffer. -func put(emitter *yaml_emitter_t, value byte) bool { - if emitter.buffer_pos+5 >= len(emitter.buffer) && !yaml_emitter_flush(emitter) { - return false - } - emitter.buffer[emitter.buffer_pos] = value - emitter.buffer_pos++ - emitter.column++ - return true -} - -// Put a line break to the output buffer. -func put_break(emitter *yaml_emitter_t) bool { - if emitter.buffer_pos+5 >= len(emitter.buffer) && !yaml_emitter_flush(emitter) { - return false - } - switch emitter.line_break { - case yaml_CR_BREAK: - emitter.buffer[emitter.buffer_pos] = '\r' - emitter.buffer_pos += 1 - case yaml_LN_BREAK: - emitter.buffer[emitter.buffer_pos] = '\n' - emitter.buffer_pos += 1 - case yaml_CRLN_BREAK: - emitter.buffer[emitter.buffer_pos+0] = '\r' - emitter.buffer[emitter.buffer_pos+1] = '\n' - emitter.buffer_pos += 2 - default: - panic("unknown line break setting") - } - emitter.column = 0 - emitter.line++ - return true -} - -// Copy a character from a string into buffer. -func write(emitter *yaml_emitter_t, s []byte, i *int) bool { - if emitter.buffer_pos+5 >= len(emitter.buffer) && !yaml_emitter_flush(emitter) { - return false - } - p := emitter.buffer_pos - w := width(s[*i]) - switch w { - case 4: - emitter.buffer[p+3] = s[*i+3] - fallthrough - case 3: - emitter.buffer[p+2] = s[*i+2] - fallthrough - case 2: - emitter.buffer[p+1] = s[*i+1] - fallthrough - case 1: - emitter.buffer[p+0] = s[*i+0] - default: - panic("unknown character width") - } - emitter.column++ - emitter.buffer_pos += w - *i += w - return true -} - -// Write a whole string into buffer. -func write_all(emitter *yaml_emitter_t, s []byte) bool { - for i := 0; i < len(s); { - if !write(emitter, s, &i) { - return false - } - } - return true -} - -// Copy a line break character from a string into buffer. -func write_break(emitter *yaml_emitter_t, s []byte, i *int) bool { - if s[*i] == '\n' { - if !put_break(emitter) { - return false - } - *i++ - } else { - if !write(emitter, s, i) { - return false - } - emitter.column = 0 - emitter.line++ - } - return true -} - -// Set an emitter error and return false. -func yaml_emitter_set_emitter_error(emitter *yaml_emitter_t, problem string) bool { - emitter.error = yaml_EMITTER_ERROR - emitter.problem = problem - return false -} - -// Emit an event. -func yaml_emitter_emit(emitter *yaml_emitter_t, event *yaml_event_t) bool { - emitter.events = append(emitter.events, *event) - for !yaml_emitter_need_more_events(emitter) { - event := &emitter.events[emitter.events_head] - if !yaml_emitter_analyze_event(emitter, event) { - return false - } - if !yaml_emitter_state_machine(emitter, event) { - return false - } - yaml_event_delete(event) - emitter.events_head++ - } - return true -} - -// Check if we need to accumulate more events before emitting. -// -// We accumulate extra -// - 1 event for DOCUMENT-START -// - 2 events for SEQUENCE-START -// - 3 events for MAPPING-START -// -func yaml_emitter_need_more_events(emitter *yaml_emitter_t) bool { - if emitter.events_head == len(emitter.events) { - return true - } - var accumulate int - switch emitter.events[emitter.events_head].typ { - case yaml_DOCUMENT_START_EVENT: - accumulate = 1 - break - case yaml_SEQUENCE_START_EVENT: - accumulate = 2 - break - case yaml_MAPPING_START_EVENT: - accumulate = 3 - break - default: - return false - } - if len(emitter.events)-emitter.events_head > accumulate { - return false - } - var level int - for i := emitter.events_head; i < len(emitter.events); i++ { - switch emitter.events[i].typ { - case yaml_STREAM_START_EVENT, yaml_DOCUMENT_START_EVENT, yaml_SEQUENCE_START_EVENT, yaml_MAPPING_START_EVENT: - level++ - case yaml_STREAM_END_EVENT, yaml_DOCUMENT_END_EVENT, yaml_SEQUENCE_END_EVENT, yaml_MAPPING_END_EVENT: - level-- - } - if level == 0 { - return false - } - } - return true -} - -// Append a directive to the directives stack. -func yaml_emitter_append_tag_directive(emitter *yaml_emitter_t, value *yaml_tag_directive_t, allow_duplicates bool) bool { - for i := 0; i < len(emitter.tag_directives); i++ { - if bytes.Equal(value.handle, emitter.tag_directives[i].handle) { - if allow_duplicates { - return true - } - return yaml_emitter_set_emitter_error(emitter, "duplicate %TAG directive") - } - } - - // [Go] Do we actually need to copy this given garbage collection - // and the lack of deallocating destructors? - tag_copy := yaml_tag_directive_t{ - handle: make([]byte, len(value.handle)), - prefix: make([]byte, len(value.prefix)), - } - copy(tag_copy.handle, value.handle) - copy(tag_copy.prefix, value.prefix) - emitter.tag_directives = append(emitter.tag_directives, tag_copy) - return true -} - -// Increase the indentation level. -func yaml_emitter_increase_indent(emitter *yaml_emitter_t, flow, indentless bool) bool { - emitter.indents = append(emitter.indents, emitter.indent) - if emitter.indent < 0 { - if flow { - emitter.indent = emitter.best_indent - } else { - emitter.indent = 0 - } - } else if !indentless { - emitter.indent += emitter.best_indent - } - return true -} - -// State dispatcher. -func yaml_emitter_state_machine(emitter *yaml_emitter_t, event *yaml_event_t) bool { - switch emitter.state { - default: - case yaml_EMIT_STREAM_START_STATE: - return yaml_emitter_emit_stream_start(emitter, event) - - case yaml_EMIT_FIRST_DOCUMENT_START_STATE: - return yaml_emitter_emit_document_start(emitter, event, true) - - case yaml_EMIT_DOCUMENT_START_STATE: - return yaml_emitter_emit_document_start(emitter, event, false) - - case yaml_EMIT_DOCUMENT_CONTENT_STATE: - return yaml_emitter_emit_document_content(emitter, event) - - case yaml_EMIT_DOCUMENT_END_STATE: - return yaml_emitter_emit_document_end(emitter, event) - - case yaml_EMIT_FLOW_SEQUENCE_FIRST_ITEM_STATE: - return yaml_emitter_emit_flow_sequence_item(emitter, event, true) - - case yaml_EMIT_FLOW_SEQUENCE_ITEM_STATE: - return yaml_emitter_emit_flow_sequence_item(emitter, event, false) - - case yaml_EMIT_FLOW_MAPPING_FIRST_KEY_STATE: - return yaml_emitter_emit_flow_mapping_key(emitter, event, true) - - case yaml_EMIT_FLOW_MAPPING_KEY_STATE: - return yaml_emitter_emit_flow_mapping_key(emitter, event, false) - - case yaml_EMIT_FLOW_MAPPING_SIMPLE_VALUE_STATE: - return yaml_emitter_emit_flow_mapping_value(emitter, event, true) - - case yaml_EMIT_FLOW_MAPPING_VALUE_STATE: - return yaml_emitter_emit_flow_mapping_value(emitter, event, false) - - case yaml_EMIT_BLOCK_SEQUENCE_FIRST_ITEM_STATE: - return yaml_emitter_emit_block_sequence_item(emitter, event, true) - - case yaml_EMIT_BLOCK_SEQUENCE_ITEM_STATE: - return yaml_emitter_emit_block_sequence_item(emitter, event, false) - - case yaml_EMIT_BLOCK_MAPPING_FIRST_KEY_STATE: - return yaml_emitter_emit_block_mapping_key(emitter, event, true) - - case yaml_EMIT_BLOCK_MAPPING_KEY_STATE: - return yaml_emitter_emit_block_mapping_key(emitter, event, false) - - case yaml_EMIT_BLOCK_MAPPING_SIMPLE_VALUE_STATE: - return yaml_emitter_emit_block_mapping_value(emitter, event, true) - - case yaml_EMIT_BLOCK_MAPPING_VALUE_STATE: - return yaml_emitter_emit_block_mapping_value(emitter, event, false) - - case yaml_EMIT_END_STATE: - return yaml_emitter_set_emitter_error(emitter, "expected nothing after STREAM-END") - } - panic("invalid emitter state") -} - -// Expect STREAM-START. -func yaml_emitter_emit_stream_start(emitter *yaml_emitter_t, event *yaml_event_t) bool { - if event.typ != yaml_STREAM_START_EVENT { - return yaml_emitter_set_emitter_error(emitter, "expected STREAM-START") - } - if emitter.encoding == yaml_ANY_ENCODING { - emitter.encoding = event.encoding - if emitter.encoding == yaml_ANY_ENCODING { - emitter.encoding = yaml_UTF8_ENCODING - } - } - if emitter.best_indent < 2 || emitter.best_indent > 9 { - emitter.best_indent = 2 - } - if emitter.best_width >= 0 && emitter.best_width <= emitter.best_indent*2 { - emitter.best_width = 80 - } - if emitter.best_width < 0 { - emitter.best_width = 1<<31 - 1 - } - if emitter.line_break == yaml_ANY_BREAK { - emitter.line_break = yaml_LN_BREAK - } - - emitter.indent = -1 - emitter.line = 0 - emitter.column = 0 - emitter.whitespace = true - emitter.indention = true - - if emitter.encoding != yaml_UTF8_ENCODING { - if !yaml_emitter_write_bom(emitter) { - return false - } - } - emitter.state = yaml_EMIT_FIRST_DOCUMENT_START_STATE - return true -} - -// Expect DOCUMENT-START or STREAM-END. -func yaml_emitter_emit_document_start(emitter *yaml_emitter_t, event *yaml_event_t, first bool) bool { - - if event.typ == yaml_DOCUMENT_START_EVENT { - - if event.version_directive != nil { - if !yaml_emitter_analyze_version_directive(emitter, event.version_directive) { - return false - } - } - - for i := 0; i < len(event.tag_directives); i++ { - tag_directive := &event.tag_directives[i] - if !yaml_emitter_analyze_tag_directive(emitter, tag_directive) { - return false - } - if !yaml_emitter_append_tag_directive(emitter, tag_directive, false) { - return false - } - } - - for i := 0; i < len(default_tag_directives); i++ { - tag_directive := &default_tag_directives[i] - if !yaml_emitter_append_tag_directive(emitter, tag_directive, true) { - return false - } - } - - implicit := event.implicit - if !first || emitter.canonical { - implicit = false - } - - if emitter.open_ended && (event.version_directive != nil || len(event.tag_directives) > 0) { - if !yaml_emitter_write_indicator(emitter, []byte("..."), true, false, false) { - return false - } - if !yaml_emitter_write_indent(emitter) { - return false - } - } - - if event.version_directive != nil { - implicit = false - if !yaml_emitter_write_indicator(emitter, []byte("%YAML"), true, false, false) { - return false - } - if !yaml_emitter_write_indicator(emitter, []byte("1.1"), true, false, false) { - return false - } - if !yaml_emitter_write_indent(emitter) { - return false - } - } - - if len(event.tag_directives) > 0 { - implicit = false - for i := 0; i < len(event.tag_directives); i++ { - tag_directive := &event.tag_directives[i] - if !yaml_emitter_write_indicator(emitter, []byte("%TAG"), true, false, false) { - return false - } - if !yaml_emitter_write_tag_handle(emitter, tag_directive.handle) { - return false - } - if !yaml_emitter_write_tag_content(emitter, tag_directive.prefix, true) { - return false - } - if !yaml_emitter_write_indent(emitter) { - return false - } - } - } - - if yaml_emitter_check_empty_document(emitter) { - implicit = false - } - if !implicit { - if !yaml_emitter_write_indent(emitter) { - return false - } - if !yaml_emitter_write_indicator(emitter, []byte("---"), true, false, false) { - return false - } - if emitter.canonical { - if !yaml_emitter_write_indent(emitter) { - return false - } - } - } - - emitter.state = yaml_EMIT_DOCUMENT_CONTENT_STATE - return true - } - - if event.typ == yaml_STREAM_END_EVENT { - if emitter.open_ended { - if !yaml_emitter_write_indicator(emitter, []byte("..."), true, false, false) { - return false - } - if !yaml_emitter_write_indent(emitter) { - return false - } - } - if !yaml_emitter_flush(emitter) { - return false - } - emitter.state = yaml_EMIT_END_STATE - return true - } - - return yaml_emitter_set_emitter_error(emitter, "expected DOCUMENT-START or STREAM-END") -} - -// Expect the root node. -func yaml_emitter_emit_document_content(emitter *yaml_emitter_t, event *yaml_event_t) bool { - emitter.states = append(emitter.states, yaml_EMIT_DOCUMENT_END_STATE) - return yaml_emitter_emit_node(emitter, event, true, false, false, false) -} - -// Expect DOCUMENT-END. -func yaml_emitter_emit_document_end(emitter *yaml_emitter_t, event *yaml_event_t) bool { - if event.typ != yaml_DOCUMENT_END_EVENT { - return yaml_emitter_set_emitter_error(emitter, "expected DOCUMENT-END") - } - if !yaml_emitter_write_indent(emitter) { - return false - } - if !event.implicit { - // [Go] Allocate the slice elsewhere. - if !yaml_emitter_write_indicator(emitter, []byte("..."), true, false, false) { - return false - } - if !yaml_emitter_write_indent(emitter) { - return false - } - } - if !yaml_emitter_flush(emitter) { - return false - } - emitter.state = yaml_EMIT_DOCUMENT_START_STATE - emitter.tag_directives = emitter.tag_directives[:0] - return true -} - -// Expect a flow item node. -func yaml_emitter_emit_flow_sequence_item(emitter *yaml_emitter_t, event *yaml_event_t, first bool) bool { - if first { - if !yaml_emitter_write_indicator(emitter, []byte{'['}, true, true, false) { - return false - } - if !yaml_emitter_increase_indent(emitter, true, false) { - return false - } - emitter.flow_level++ - } - - if event.typ == yaml_SEQUENCE_END_EVENT { - emitter.flow_level-- - emitter.indent = emitter.indents[len(emitter.indents)-1] - emitter.indents = emitter.indents[:len(emitter.indents)-1] - if emitter.canonical && !first { - if !yaml_emitter_write_indicator(emitter, []byte{','}, false, false, false) { - return false - } - if !yaml_emitter_write_indent(emitter) { - return false - } - } - if !yaml_emitter_write_indicator(emitter, []byte{']'}, false, false, false) { - return false - } - emitter.state = emitter.states[len(emitter.states)-1] - emitter.states = emitter.states[:len(emitter.states)-1] - - return true - } - - if !first { - if !yaml_emitter_write_indicator(emitter, []byte{','}, false, false, false) { - return false - } - } - - if emitter.canonical || emitter.column > emitter.best_width { - if !yaml_emitter_write_indent(emitter) { - return false - } - } - emitter.states = append(emitter.states, yaml_EMIT_FLOW_SEQUENCE_ITEM_STATE) - return yaml_emitter_emit_node(emitter, event, false, true, false, false) -} - -// Expect a flow key node. -func yaml_emitter_emit_flow_mapping_key(emitter *yaml_emitter_t, event *yaml_event_t, first bool) bool { - if first { - if !yaml_emitter_write_indicator(emitter, []byte{'{'}, true, true, false) { - return false - } - if !yaml_emitter_increase_indent(emitter, true, false) { - return false - } - emitter.flow_level++ - } - - if event.typ == yaml_MAPPING_END_EVENT { - emitter.flow_level-- - emitter.indent = emitter.indents[len(emitter.indents)-1] - emitter.indents = emitter.indents[:len(emitter.indents)-1] - if emitter.canonical && !first { - if !yaml_emitter_write_indicator(emitter, []byte{','}, false, false, false) { - return false - } - if !yaml_emitter_write_indent(emitter) { - return false - } - } - if !yaml_emitter_write_indicator(emitter, []byte{'}'}, false, false, false) { - return false - } - emitter.state = emitter.states[len(emitter.states)-1] - emitter.states = emitter.states[:len(emitter.states)-1] - return true - } - - if !first { - if !yaml_emitter_write_indicator(emitter, []byte{','}, false, false, false) { - return false - } - } - if emitter.canonical || emitter.column > emitter.best_width { - if !yaml_emitter_write_indent(emitter) { - return false - } - } - - if !emitter.canonical && yaml_emitter_check_simple_key(emitter) { - emitter.states = append(emitter.states, yaml_EMIT_FLOW_MAPPING_SIMPLE_VALUE_STATE) - return yaml_emitter_emit_node(emitter, event, false, false, true, true) - } - if !yaml_emitter_write_indicator(emitter, []byte{'?'}, true, false, false) { - return false - } - emitter.states = append(emitter.states, yaml_EMIT_FLOW_MAPPING_VALUE_STATE) - return yaml_emitter_emit_node(emitter, event, false, false, true, false) -} - -// Expect a flow value node. -func yaml_emitter_emit_flow_mapping_value(emitter *yaml_emitter_t, event *yaml_event_t, simple bool) bool { - if simple { - if !yaml_emitter_write_indicator(emitter, []byte{':'}, false, false, false) { - return false - } - } else { - if emitter.canonical || emitter.column > emitter.best_width { - if !yaml_emitter_write_indent(emitter) { - return false - } - } - if !yaml_emitter_write_indicator(emitter, []byte{':'}, true, false, false) { - return false - } - } - emitter.states = append(emitter.states, yaml_EMIT_FLOW_MAPPING_KEY_STATE) - return yaml_emitter_emit_node(emitter, event, false, false, true, false) -} - -// Expect a block item node. -func yaml_emitter_emit_block_sequence_item(emitter *yaml_emitter_t, event *yaml_event_t, first bool) bool { - if first { - if !yaml_emitter_increase_indent(emitter, false, emitter.mapping_context && !emitter.indention) { - return false - } - } - if event.typ == yaml_SEQUENCE_END_EVENT { - emitter.indent = emitter.indents[len(emitter.indents)-1] - emitter.indents = emitter.indents[:len(emitter.indents)-1] - emitter.state = emitter.states[len(emitter.states)-1] - emitter.states = emitter.states[:len(emitter.states)-1] - return true - } - if !yaml_emitter_write_indent(emitter) { - return false - } - if !yaml_emitter_write_indicator(emitter, []byte{'-'}, true, false, true) { - return false - } - emitter.states = append(emitter.states, yaml_EMIT_BLOCK_SEQUENCE_ITEM_STATE) - return yaml_emitter_emit_node(emitter, event, false, true, false, false) -} - -// Expect a block key node. -func yaml_emitter_emit_block_mapping_key(emitter *yaml_emitter_t, event *yaml_event_t, first bool) bool { - if first { - if !yaml_emitter_increase_indent(emitter, false, false) { - return false - } - } - if event.typ == yaml_MAPPING_END_EVENT { - emitter.indent = emitter.indents[len(emitter.indents)-1] - emitter.indents = emitter.indents[:len(emitter.indents)-1] - emitter.state = emitter.states[len(emitter.states)-1] - emitter.states = emitter.states[:len(emitter.states)-1] - return true - } - if !yaml_emitter_write_indent(emitter) { - return false - } - if yaml_emitter_check_simple_key(emitter) { - emitter.states = append(emitter.states, yaml_EMIT_BLOCK_MAPPING_SIMPLE_VALUE_STATE) - return yaml_emitter_emit_node(emitter, event, false, false, true, true) - } - if !yaml_emitter_write_indicator(emitter, []byte{'?'}, true, false, true) { - return false - } - emitter.states = append(emitter.states, yaml_EMIT_BLOCK_MAPPING_VALUE_STATE) - return yaml_emitter_emit_node(emitter, event, false, false, true, false) -} - -// Expect a block value node. -func yaml_emitter_emit_block_mapping_value(emitter *yaml_emitter_t, event *yaml_event_t, simple bool) bool { - if simple { - if !yaml_emitter_write_indicator(emitter, []byte{':'}, false, false, false) { - return false - } - } else { - if !yaml_emitter_write_indent(emitter) { - return false - } - if !yaml_emitter_write_indicator(emitter, []byte{':'}, true, false, true) { - return false - } - } - emitter.states = append(emitter.states, yaml_EMIT_BLOCK_MAPPING_KEY_STATE) - return yaml_emitter_emit_node(emitter, event, false, false, true, false) -} - -// Expect a node. -func yaml_emitter_emit_node(emitter *yaml_emitter_t, event *yaml_event_t, - root bool, sequence bool, mapping bool, simple_key bool) bool { - - emitter.root_context = root - emitter.sequence_context = sequence - emitter.mapping_context = mapping - emitter.simple_key_context = simple_key - - switch event.typ { - case yaml_ALIAS_EVENT: - return yaml_emitter_emit_alias(emitter, event) - case yaml_SCALAR_EVENT: - return yaml_emitter_emit_scalar(emitter, event) - case yaml_SEQUENCE_START_EVENT: - return yaml_emitter_emit_sequence_start(emitter, event) - case yaml_MAPPING_START_EVENT: - return yaml_emitter_emit_mapping_start(emitter, event) - default: - return yaml_emitter_set_emitter_error(emitter, - fmt.Sprintf("expected SCALAR, SEQUENCE-START, MAPPING-START, or ALIAS, but got %v", event.typ)) - } -} - -// Expect ALIAS. -func yaml_emitter_emit_alias(emitter *yaml_emitter_t, event *yaml_event_t) bool { - if !yaml_emitter_process_anchor(emitter) { - return false - } - emitter.state = emitter.states[len(emitter.states)-1] - emitter.states = emitter.states[:len(emitter.states)-1] - return true -} - -// Expect SCALAR. -func yaml_emitter_emit_scalar(emitter *yaml_emitter_t, event *yaml_event_t) bool { - if !yaml_emitter_select_scalar_style(emitter, event) { - return false - } - if !yaml_emitter_process_anchor(emitter) { - return false - } - if !yaml_emitter_process_tag(emitter) { - return false - } - if !yaml_emitter_increase_indent(emitter, true, false) { - return false - } - if !yaml_emitter_process_scalar(emitter) { - return false - } - emitter.indent = emitter.indents[len(emitter.indents)-1] - emitter.indents = emitter.indents[:len(emitter.indents)-1] - emitter.state = emitter.states[len(emitter.states)-1] - emitter.states = emitter.states[:len(emitter.states)-1] - return true -} - -// Expect SEQUENCE-START. -func yaml_emitter_emit_sequence_start(emitter *yaml_emitter_t, event *yaml_event_t) bool { - if !yaml_emitter_process_anchor(emitter) { - return false - } - if !yaml_emitter_process_tag(emitter) { - return false - } - if emitter.flow_level > 0 || emitter.canonical || event.sequence_style() == yaml_FLOW_SEQUENCE_STYLE || - yaml_emitter_check_empty_sequence(emitter) { - emitter.state = yaml_EMIT_FLOW_SEQUENCE_FIRST_ITEM_STATE - } else { - emitter.state = yaml_EMIT_BLOCK_SEQUENCE_FIRST_ITEM_STATE - } - return true -} - -// Expect MAPPING-START. -func yaml_emitter_emit_mapping_start(emitter *yaml_emitter_t, event *yaml_event_t) bool { - if !yaml_emitter_process_anchor(emitter) { - return false - } - if !yaml_emitter_process_tag(emitter) { - return false - } - if emitter.flow_level > 0 || emitter.canonical || event.mapping_style() == yaml_FLOW_MAPPING_STYLE || - yaml_emitter_check_empty_mapping(emitter) { - emitter.state = yaml_EMIT_FLOW_MAPPING_FIRST_KEY_STATE - } else { - emitter.state = yaml_EMIT_BLOCK_MAPPING_FIRST_KEY_STATE - } - return true -} - -// Check if the document content is an empty scalar. -func yaml_emitter_check_empty_document(emitter *yaml_emitter_t) bool { - return false // [Go] Huh? -} - -// Check if the next events represent an empty sequence. -func yaml_emitter_check_empty_sequence(emitter *yaml_emitter_t) bool { - if len(emitter.events)-emitter.events_head < 2 { - return false - } - return emitter.events[emitter.events_head].typ == yaml_SEQUENCE_START_EVENT && - emitter.events[emitter.events_head+1].typ == yaml_SEQUENCE_END_EVENT -} - -// Check if the next events represent an empty mapping. -func yaml_emitter_check_empty_mapping(emitter *yaml_emitter_t) bool { - if len(emitter.events)-emitter.events_head < 2 { - return false - } - return emitter.events[emitter.events_head].typ == yaml_MAPPING_START_EVENT && - emitter.events[emitter.events_head+1].typ == yaml_MAPPING_END_EVENT -} - -// Check if the next node can be expressed as a simple key. -func yaml_emitter_check_simple_key(emitter *yaml_emitter_t) bool { - length := 0 - switch emitter.events[emitter.events_head].typ { - case yaml_ALIAS_EVENT: - length += len(emitter.anchor_data.anchor) - case yaml_SCALAR_EVENT: - if emitter.scalar_data.multiline { - return false - } - length += len(emitter.anchor_data.anchor) + - len(emitter.tag_data.handle) + - len(emitter.tag_data.suffix) + - len(emitter.scalar_data.value) - case yaml_SEQUENCE_START_EVENT: - if !yaml_emitter_check_empty_sequence(emitter) { - return false - } - length += len(emitter.anchor_data.anchor) + - len(emitter.tag_data.handle) + - len(emitter.tag_data.suffix) - case yaml_MAPPING_START_EVENT: - if !yaml_emitter_check_empty_mapping(emitter) { - return false - } - length += len(emitter.anchor_data.anchor) + - len(emitter.tag_data.handle) + - len(emitter.tag_data.suffix) - default: - return false - } - return length <= 128 -} - -// Determine an acceptable scalar style. -func yaml_emitter_select_scalar_style(emitter *yaml_emitter_t, event *yaml_event_t) bool { - - no_tag := len(emitter.tag_data.handle) == 0 && len(emitter.tag_data.suffix) == 0 - if no_tag && !event.implicit && !event.quoted_implicit { - return yaml_emitter_set_emitter_error(emitter, "neither tag nor implicit flags are specified") - } - - style := event.scalar_style() - if style == yaml_ANY_SCALAR_STYLE { - style = yaml_PLAIN_SCALAR_STYLE - } - if emitter.canonical { - style = yaml_DOUBLE_QUOTED_SCALAR_STYLE - } - if emitter.simple_key_context && emitter.scalar_data.multiline { - style = yaml_DOUBLE_QUOTED_SCALAR_STYLE - } - - if style == yaml_PLAIN_SCALAR_STYLE { - if emitter.flow_level > 0 && !emitter.scalar_data.flow_plain_allowed || - emitter.flow_level == 0 && !emitter.scalar_data.block_plain_allowed { - style = yaml_SINGLE_QUOTED_SCALAR_STYLE - } - if len(emitter.scalar_data.value) == 0 && (emitter.flow_level > 0 || emitter.simple_key_context) { - style = yaml_SINGLE_QUOTED_SCALAR_STYLE - } - if no_tag && !event.implicit { - style = yaml_SINGLE_QUOTED_SCALAR_STYLE - } - } - if style == yaml_SINGLE_QUOTED_SCALAR_STYLE { - if !emitter.scalar_data.single_quoted_allowed { - style = yaml_DOUBLE_QUOTED_SCALAR_STYLE - } - } - if style == yaml_LITERAL_SCALAR_STYLE || style == yaml_FOLDED_SCALAR_STYLE { - if !emitter.scalar_data.block_allowed || emitter.flow_level > 0 || emitter.simple_key_context { - style = yaml_DOUBLE_QUOTED_SCALAR_STYLE - } - } - - if no_tag && !event.quoted_implicit && style != yaml_PLAIN_SCALAR_STYLE { - emitter.tag_data.handle = []byte{'!'} - } - emitter.scalar_data.style = style - return true -} - -// Write an anchor. -func yaml_emitter_process_anchor(emitter *yaml_emitter_t) bool { - if emitter.anchor_data.anchor == nil { - return true - } - c := []byte{'&'} - if emitter.anchor_data.alias { - c[0] = '*' - } - if !yaml_emitter_write_indicator(emitter, c, true, false, false) { - return false - } - return yaml_emitter_write_anchor(emitter, emitter.anchor_data.anchor) -} - -// Write a tag. -func yaml_emitter_process_tag(emitter *yaml_emitter_t) bool { - if len(emitter.tag_data.handle) == 0 && len(emitter.tag_data.suffix) == 0 { - return true - } - if len(emitter.tag_data.handle) > 0 { - if !yaml_emitter_write_tag_handle(emitter, emitter.tag_data.handle) { - return false - } - if len(emitter.tag_data.suffix) > 0 { - if !yaml_emitter_write_tag_content(emitter, emitter.tag_data.suffix, false) { - return false - } - } - } else { - // [Go] Allocate these slices elsewhere. - if !yaml_emitter_write_indicator(emitter, []byte("!<"), true, false, false) { - return false - } - if !yaml_emitter_write_tag_content(emitter, emitter.tag_data.suffix, false) { - return false - } - if !yaml_emitter_write_indicator(emitter, []byte{'>'}, false, false, false) { - return false - } - } - return true -} - -// Write a scalar. -func yaml_emitter_process_scalar(emitter *yaml_emitter_t) bool { - switch emitter.scalar_data.style { - case yaml_PLAIN_SCALAR_STYLE: - return yaml_emitter_write_plain_scalar(emitter, emitter.scalar_data.value, !emitter.simple_key_context) - - case yaml_SINGLE_QUOTED_SCALAR_STYLE: - return yaml_emitter_write_single_quoted_scalar(emitter, emitter.scalar_data.value, !emitter.simple_key_context) - - case yaml_DOUBLE_QUOTED_SCALAR_STYLE: - return yaml_emitter_write_double_quoted_scalar(emitter, emitter.scalar_data.value, !emitter.simple_key_context) - - case yaml_LITERAL_SCALAR_STYLE: - return yaml_emitter_write_literal_scalar(emitter, emitter.scalar_data.value) - - case yaml_FOLDED_SCALAR_STYLE: - return yaml_emitter_write_folded_scalar(emitter, emitter.scalar_data.value) - } - panic("unknown scalar style") -} - -// Check if a %YAML directive is valid. -func yaml_emitter_analyze_version_directive(emitter *yaml_emitter_t, version_directive *yaml_version_directive_t) bool { - if version_directive.major != 1 || version_directive.minor != 1 { - return yaml_emitter_set_emitter_error(emitter, "incompatible %YAML directive") - } - return true -} - -// Check if a %TAG directive is valid. -func yaml_emitter_analyze_tag_directive(emitter *yaml_emitter_t, tag_directive *yaml_tag_directive_t) bool { - handle := tag_directive.handle - prefix := tag_directive.prefix - if len(handle) == 0 { - return yaml_emitter_set_emitter_error(emitter, "tag handle must not be empty") - } - if handle[0] != '!' { - return yaml_emitter_set_emitter_error(emitter, "tag handle must start with '!'") - } - if handle[len(handle)-1] != '!' { - return yaml_emitter_set_emitter_error(emitter, "tag handle must end with '!'") - } - for i := 1; i < len(handle)-1; i += width(handle[i]) { - if !is_alpha(handle, i) { - return yaml_emitter_set_emitter_error(emitter, "tag handle must contain alphanumerical characters only") - } - } - if len(prefix) == 0 { - return yaml_emitter_set_emitter_error(emitter, "tag prefix must not be empty") - } - return true -} - -// Check if an anchor is valid. -func yaml_emitter_analyze_anchor(emitter *yaml_emitter_t, anchor []byte, alias bool) bool { - if len(anchor) == 0 { - problem := "anchor value must not be empty" - if alias { - problem = "alias value must not be empty" - } - return yaml_emitter_set_emitter_error(emitter, problem) - } - for i := 0; i < len(anchor); i += width(anchor[i]) { - if !is_alpha(anchor, i) { - problem := "anchor value must contain alphanumerical characters only" - if alias { - problem = "alias value must contain alphanumerical characters only" - } - return yaml_emitter_set_emitter_error(emitter, problem) - } - } - emitter.anchor_data.anchor = anchor - emitter.anchor_data.alias = alias - return true -} - -// Check if a tag is valid. -func yaml_emitter_analyze_tag(emitter *yaml_emitter_t, tag []byte) bool { - if len(tag) == 0 { - return yaml_emitter_set_emitter_error(emitter, "tag value must not be empty") - } - for i := 0; i < len(emitter.tag_directives); i++ { - tag_directive := &emitter.tag_directives[i] - if bytes.HasPrefix(tag, tag_directive.prefix) { - emitter.tag_data.handle = tag_directive.handle - emitter.tag_data.suffix = tag[len(tag_directive.prefix):] - return true - } - } - emitter.tag_data.suffix = tag - return true -} - -// Check if a scalar is valid. -func yaml_emitter_analyze_scalar(emitter *yaml_emitter_t, value []byte) bool { - var ( - block_indicators = false - flow_indicators = false - line_breaks = false - special_characters = false - - leading_space = false - leading_break = false - trailing_space = false - trailing_break = false - break_space = false - space_break = false - - preceded_by_whitespace = false - followed_by_whitespace = false - previous_space = false - previous_break = false - ) - - emitter.scalar_data.value = value - - if len(value) == 0 { - emitter.scalar_data.multiline = false - emitter.scalar_data.flow_plain_allowed = false - emitter.scalar_data.block_plain_allowed = true - emitter.scalar_data.single_quoted_allowed = true - emitter.scalar_data.block_allowed = false - return true - } - - if len(value) >= 3 && ((value[0] == '-' && value[1] == '-' && value[2] == '-') || (value[0] == '.' && value[1] == '.' && value[2] == '.')) { - block_indicators = true - flow_indicators = true - } - - preceded_by_whitespace = true - for i, w := 0, 0; i < len(value); i += w { - w = width(value[i]) - followed_by_whitespace = i+w >= len(value) || is_blank(value, i+w) - - if i == 0 { - switch value[i] { - case '#', ',', '[', ']', '{', '}', '&', '*', '!', '|', '>', '\'', '"', '%', '@', '`': - flow_indicators = true - block_indicators = true - case '?', ':': - flow_indicators = true - if followed_by_whitespace { - block_indicators = true - } - case '-': - if followed_by_whitespace { - flow_indicators = true - block_indicators = true - } - } - } else { - switch value[i] { - case ',', '?', '[', ']', '{', '}': - flow_indicators = true - case ':': - flow_indicators = true - if followed_by_whitespace { - block_indicators = true - } - case '#': - if preceded_by_whitespace { - flow_indicators = true - block_indicators = true - } - } - } - - if !is_printable(value, i) || !is_ascii(value, i) && !emitter.unicode { - special_characters = true - } - if is_space(value, i) { - if i == 0 { - leading_space = true - } - if i+width(value[i]) == len(value) { - trailing_space = true - } - if previous_break { - break_space = true - } - previous_space = true - previous_break = false - } else if is_break(value, i) { - line_breaks = true - if i == 0 { - leading_break = true - } - if i+width(value[i]) == len(value) { - trailing_break = true - } - if previous_space { - space_break = true - } - previous_space = false - previous_break = true - } else { - previous_space = false - previous_break = false - } - - // [Go]: Why 'z'? Couldn't be the end of the string as that's the loop condition. - preceded_by_whitespace = is_blankz(value, i) - } - - emitter.scalar_data.multiline = line_breaks - emitter.scalar_data.flow_plain_allowed = true - emitter.scalar_data.block_plain_allowed = true - emitter.scalar_data.single_quoted_allowed = true - emitter.scalar_data.block_allowed = true - - if leading_space || leading_break || trailing_space || trailing_break { - emitter.scalar_data.flow_plain_allowed = false - emitter.scalar_data.block_plain_allowed = false - } - if trailing_space { - emitter.scalar_data.block_allowed = false - } - if break_space { - emitter.scalar_data.flow_plain_allowed = false - emitter.scalar_data.block_plain_allowed = false - emitter.scalar_data.single_quoted_allowed = false - } - if space_break || special_characters { - emitter.scalar_data.flow_plain_allowed = false - emitter.scalar_data.block_plain_allowed = false - emitter.scalar_data.single_quoted_allowed = false - emitter.scalar_data.block_allowed = false - } - if line_breaks { - emitter.scalar_data.flow_plain_allowed = false - emitter.scalar_data.block_plain_allowed = false - } - if flow_indicators { - emitter.scalar_data.flow_plain_allowed = false - } - if block_indicators { - emitter.scalar_data.block_plain_allowed = false - } - return true -} - -// Check if the event data is valid. -func yaml_emitter_analyze_event(emitter *yaml_emitter_t, event *yaml_event_t) bool { - - emitter.anchor_data.anchor = nil - emitter.tag_data.handle = nil - emitter.tag_data.suffix = nil - emitter.scalar_data.value = nil - - switch event.typ { - case yaml_ALIAS_EVENT: - if !yaml_emitter_analyze_anchor(emitter, event.anchor, true) { - return false - } - - case yaml_SCALAR_EVENT: - if len(event.anchor) > 0 { - if !yaml_emitter_analyze_anchor(emitter, event.anchor, false) { - return false - } - } - if len(event.tag) > 0 && (emitter.canonical || (!event.implicit && !event.quoted_implicit)) { - if !yaml_emitter_analyze_tag(emitter, event.tag) { - return false - } - } - if !yaml_emitter_analyze_scalar(emitter, event.value) { - return false - } - - case yaml_SEQUENCE_START_EVENT: - if len(event.anchor) > 0 { - if !yaml_emitter_analyze_anchor(emitter, event.anchor, false) { - return false - } - } - if len(event.tag) > 0 && (emitter.canonical || !event.implicit) { - if !yaml_emitter_analyze_tag(emitter, event.tag) { - return false - } - } - - case yaml_MAPPING_START_EVENT: - if len(event.anchor) > 0 { - if !yaml_emitter_analyze_anchor(emitter, event.anchor, false) { - return false - } - } - if len(event.tag) > 0 && (emitter.canonical || !event.implicit) { - if !yaml_emitter_analyze_tag(emitter, event.tag) { - return false - } - } - } - return true -} - -// Write the BOM character. -func yaml_emitter_write_bom(emitter *yaml_emitter_t) bool { - if !flush(emitter) { - return false - } - pos := emitter.buffer_pos - emitter.buffer[pos+0] = '\xEF' - emitter.buffer[pos+1] = '\xBB' - emitter.buffer[pos+2] = '\xBF' - emitter.buffer_pos += 3 - return true -} - -func yaml_emitter_write_indent(emitter *yaml_emitter_t) bool { - indent := emitter.indent - if indent < 0 { - indent = 0 - } - if !emitter.indention || emitter.column > indent || (emitter.column == indent && !emitter.whitespace) { - if !put_break(emitter) { - return false - } - } - for emitter.column < indent { - if !put(emitter, ' ') { - return false - } - } - emitter.whitespace = true - emitter.indention = true - return true -} - -func yaml_emitter_write_indicator(emitter *yaml_emitter_t, indicator []byte, need_whitespace, is_whitespace, is_indention bool) bool { - if need_whitespace && !emitter.whitespace { - if !put(emitter, ' ') { - return false - } - } - if !write_all(emitter, indicator) { - return false - } - emitter.whitespace = is_whitespace - emitter.indention = (emitter.indention && is_indention) - emitter.open_ended = false - return true -} - -func yaml_emitter_write_anchor(emitter *yaml_emitter_t, value []byte) bool { - if !write_all(emitter, value) { - return false - } - emitter.whitespace = false - emitter.indention = false - return true -} - -func yaml_emitter_write_tag_handle(emitter *yaml_emitter_t, value []byte) bool { - if !emitter.whitespace { - if !put(emitter, ' ') { - return false - } - } - if !write_all(emitter, value) { - return false - } - emitter.whitespace = false - emitter.indention = false - return true -} - -func yaml_emitter_write_tag_content(emitter *yaml_emitter_t, value []byte, need_whitespace bool) bool { - if need_whitespace && !emitter.whitespace { - if !put(emitter, ' ') { - return false - } - } - for i := 0; i < len(value); { - var must_write bool - switch value[i] { - case ';', '/', '?', ':', '@', '&', '=', '+', '$', ',', '_', '.', '~', '*', '\'', '(', ')', '[', ']': - must_write = true - default: - must_write = is_alpha(value, i) - } - if must_write { - if !write(emitter, value, &i) { - return false - } - } else { - w := width(value[i]) - for k := 0; k < w; k++ { - octet := value[i] - i++ - if !put(emitter, '%') { - return false - } - - c := octet >> 4 - if c < 10 { - c += '0' - } else { - c += 'A' - 10 - } - if !put(emitter, c) { - return false - } - - c = octet & 0x0f - if c < 10 { - c += '0' - } else { - c += 'A' - 10 - } - if !put(emitter, c) { - return false - } - } - } - } - emitter.whitespace = false - emitter.indention = false - return true -} - -func yaml_emitter_write_plain_scalar(emitter *yaml_emitter_t, value []byte, allow_breaks bool) bool { - if !emitter.whitespace { - if !put(emitter, ' ') { - return false - } - } - - spaces := false - breaks := false - for i := 0; i < len(value); { - if is_space(value, i) { - if allow_breaks && !spaces && emitter.column > emitter.best_width && !is_space(value, i+1) { - if !yaml_emitter_write_indent(emitter) { - return false - } - i += width(value[i]) - } else { - if !write(emitter, value, &i) { - return false - } - } - spaces = true - } else if is_break(value, i) { - if !breaks && value[i] == '\n' { - if !put_break(emitter) { - return false - } - } - if !write_break(emitter, value, &i) { - return false - } - emitter.indention = true - breaks = true - } else { - if breaks { - if !yaml_emitter_write_indent(emitter) { - return false - } - } - if !write(emitter, value, &i) { - return false - } - emitter.indention = false - spaces = false - breaks = false - } - } - - emitter.whitespace = false - emitter.indention = false - if emitter.root_context { - emitter.open_ended = true - } - - return true -} - -func yaml_emitter_write_single_quoted_scalar(emitter *yaml_emitter_t, value []byte, allow_breaks bool) bool { - - if !yaml_emitter_write_indicator(emitter, []byte{'\''}, true, false, false) { - return false - } - - spaces := false - breaks := false - for i := 0; i < len(value); { - if is_space(value, i) { - if allow_breaks && !spaces && emitter.column > emitter.best_width && i > 0 && i < len(value)-1 && !is_space(value, i+1) { - if !yaml_emitter_write_indent(emitter) { - return false - } - i += width(value[i]) - } else { - if !write(emitter, value, &i) { - return false - } - } - spaces = true - } else if is_break(value, i) { - if !breaks && value[i] == '\n' { - if !put_break(emitter) { - return false - } - } - if !write_break(emitter, value, &i) { - return false - } - emitter.indention = true - breaks = true - } else { - if breaks { - if !yaml_emitter_write_indent(emitter) { - return false - } - } - if value[i] == '\'' { - if !put(emitter, '\'') { - return false - } - } - if !write(emitter, value, &i) { - return false - } - emitter.indention = false - spaces = false - breaks = false - } - } - if !yaml_emitter_write_indicator(emitter, []byte{'\''}, false, false, false) { - return false - } - emitter.whitespace = false - emitter.indention = false - return true -} - -func yaml_emitter_write_double_quoted_scalar(emitter *yaml_emitter_t, value []byte, allow_breaks bool) bool { - spaces := false - if !yaml_emitter_write_indicator(emitter, []byte{'"'}, true, false, false) { - return false - } - - for i := 0; i < len(value); { - if !is_printable(value, i) || (!emitter.unicode && !is_ascii(value, i)) || - is_bom(value, i) || is_break(value, i) || - value[i] == '"' || value[i] == '\\' { - - octet := value[i] - - var w int - var v rune - switch { - case octet&0x80 == 0x00: - w, v = 1, rune(octet&0x7F) - case octet&0xE0 == 0xC0: - w, v = 2, rune(octet&0x1F) - case octet&0xF0 == 0xE0: - w, v = 3, rune(octet&0x0F) - case octet&0xF8 == 0xF0: - w, v = 4, rune(octet&0x07) - } - for k := 1; k < w; k++ { - octet = value[i+k] - v = (v << 6) + (rune(octet) & 0x3F) - } - i += w - - if !put(emitter, '\\') { - return false - } - - var ok bool - switch v { - case 0x00: - ok = put(emitter, '0') - case 0x07: - ok = put(emitter, 'a') - case 0x08: - ok = put(emitter, 'b') - case 0x09: - ok = put(emitter, 't') - case 0x0A: - ok = put(emitter, 'n') - case 0x0b: - ok = put(emitter, 'v') - case 0x0c: - ok = put(emitter, 'f') - case 0x0d: - ok = put(emitter, 'r') - case 0x1b: - ok = put(emitter, 'e') - case 0x22: - ok = put(emitter, '"') - case 0x5c: - ok = put(emitter, '\\') - case 0x85: - ok = put(emitter, 'N') - case 0xA0: - ok = put(emitter, '_') - case 0x2028: - ok = put(emitter, 'L') - case 0x2029: - ok = put(emitter, 'P') - default: - if v <= 0xFF { - ok = put(emitter, 'x') - w = 2 - } else if v <= 0xFFFF { - ok = put(emitter, 'u') - w = 4 - } else { - ok = put(emitter, 'U') - w = 8 - } - for k := (w - 1) * 4; ok && k >= 0; k -= 4 { - digit := byte((v >> uint(k)) & 0x0F) - if digit < 10 { - ok = put(emitter, digit+'0') - } else { - ok = put(emitter, digit+'A'-10) - } - } - } - if !ok { - return false - } - spaces = false - } else if is_space(value, i) { - if allow_breaks && !spaces && emitter.column > emitter.best_width && i > 0 && i < len(value)-1 { - if !yaml_emitter_write_indent(emitter) { - return false - } - if is_space(value, i+1) { - if !put(emitter, '\\') { - return false - } - } - i += width(value[i]) - } else if !write(emitter, value, &i) { - return false - } - spaces = true - } else { - if !write(emitter, value, &i) { - return false - } - spaces = false - } - } - if !yaml_emitter_write_indicator(emitter, []byte{'"'}, false, false, false) { - return false - } - emitter.whitespace = false - emitter.indention = false - return true -} - -func yaml_emitter_write_block_scalar_hints(emitter *yaml_emitter_t, value []byte) bool { - if is_space(value, 0) || is_break(value, 0) { - indent_hint := []byte{'0' + byte(emitter.best_indent)} - if !yaml_emitter_write_indicator(emitter, indent_hint, false, false, false) { - return false - } - } - - emitter.open_ended = false - - var chomp_hint [1]byte - if len(value) == 0 { - chomp_hint[0] = '-' - } else { - i := len(value) - 1 - for value[i]&0xC0 == 0x80 { - i-- - } - if !is_break(value, i) { - chomp_hint[0] = '-' - } else if i == 0 { - chomp_hint[0] = '+' - emitter.open_ended = true - } else { - i-- - for value[i]&0xC0 == 0x80 { - i-- - } - if is_break(value, i) { - chomp_hint[0] = '+' - emitter.open_ended = true - } - } - } - if chomp_hint[0] != 0 { - if !yaml_emitter_write_indicator(emitter, chomp_hint[:], false, false, false) { - return false - } - } - return true -} - -func yaml_emitter_write_literal_scalar(emitter *yaml_emitter_t, value []byte) bool { - if !yaml_emitter_write_indicator(emitter, []byte{'|'}, true, false, false) { - return false - } - if !yaml_emitter_write_block_scalar_hints(emitter, value) { - return false - } - if !put_break(emitter) { - return false - } - emitter.indention = true - emitter.whitespace = true - breaks := true - for i := 0; i < len(value); { - if is_break(value, i) { - if !write_break(emitter, value, &i) { - return false - } - emitter.indention = true - breaks = true - } else { - if breaks { - if !yaml_emitter_write_indent(emitter) { - return false - } - } - if !write(emitter, value, &i) { - return false - } - emitter.indention = false - breaks = false - } - } - - return true -} - -func yaml_emitter_write_folded_scalar(emitter *yaml_emitter_t, value []byte) bool { - if !yaml_emitter_write_indicator(emitter, []byte{'>'}, true, false, false) { - return false - } - if !yaml_emitter_write_block_scalar_hints(emitter, value) { - return false - } - - if !put_break(emitter) { - return false - } - emitter.indention = true - emitter.whitespace = true - - breaks := true - leading_spaces := true - for i := 0; i < len(value); { - if is_break(value, i) { - if !breaks && !leading_spaces && value[i] == '\n' { - k := 0 - for is_break(value, k) { - k += width(value[k]) - } - if !is_blankz(value, k) { - if !put_break(emitter) { - return false - } - } - } - if !write_break(emitter, value, &i) { - return false - } - emitter.indention = true - breaks = true - } else { - if breaks { - if !yaml_emitter_write_indent(emitter) { - return false - } - leading_spaces = is_blank(value, i) - } - if !breaks && is_space(value, i) && !is_space(value, i+1) && emitter.column > emitter.best_width { - if !yaml_emitter_write_indent(emitter) { - return false - } - i += width(value[i]) - } else { - if !write(emitter, value, &i) { - return false - } - } - emitter.indention = false - breaks = false - } - } - return true -} diff --git a/vendor/gopkg.in/yaml.v2/encode.go b/vendor/gopkg.in/yaml.v2/encode.go deleted file mode 100644 index a14435e82f..0000000000 --- a/vendor/gopkg.in/yaml.v2/encode.go +++ /dev/null @@ -1,362 +0,0 @@ -package yaml - -import ( - "encoding" - "fmt" - "io" - "reflect" - "regexp" - "sort" - "strconv" - "strings" - "time" - "unicode/utf8" -) - -type encoder struct { - emitter yaml_emitter_t - event yaml_event_t - out []byte - flow bool - // doneInit holds whether the initial stream_start_event has been - // emitted. - doneInit bool -} - -func newEncoder() *encoder { - e := &encoder{} - yaml_emitter_initialize(&e.emitter) - yaml_emitter_set_output_string(&e.emitter, &e.out) - yaml_emitter_set_unicode(&e.emitter, true) - return e -} - -func newEncoderWithWriter(w io.Writer) *encoder { - e := &encoder{} - yaml_emitter_initialize(&e.emitter) - yaml_emitter_set_output_writer(&e.emitter, w) - yaml_emitter_set_unicode(&e.emitter, true) - return e -} - -func (e *encoder) init() { - if e.doneInit { - return - } - yaml_stream_start_event_initialize(&e.event, yaml_UTF8_ENCODING) - e.emit() - e.doneInit = true -} - -func (e *encoder) finish() { - e.emitter.open_ended = false - yaml_stream_end_event_initialize(&e.event) - e.emit() -} - -func (e *encoder) destroy() { - yaml_emitter_delete(&e.emitter) -} - -func (e *encoder) emit() { - // This will internally delete the e.event value. - e.must(yaml_emitter_emit(&e.emitter, &e.event)) -} - -func (e *encoder) must(ok bool) { - if !ok { - msg := e.emitter.problem - if msg == "" { - msg = "unknown problem generating YAML content" - } - failf("%s", msg) - } -} - -func (e *encoder) marshalDoc(tag string, in reflect.Value) { - e.init() - yaml_document_start_event_initialize(&e.event, nil, nil, true) - e.emit() - e.marshal(tag, in) - yaml_document_end_event_initialize(&e.event, true) - e.emit() -} - -func (e *encoder) marshal(tag string, in reflect.Value) { - if !in.IsValid() || in.Kind() == reflect.Ptr && in.IsNil() { - e.nilv() - return - } - iface := in.Interface() - switch m := iface.(type) { - case time.Time, *time.Time: - // Although time.Time implements TextMarshaler, - // we don't want to treat it as a string for YAML - // purposes because YAML has special support for - // timestamps. - case Marshaler: - v, err := m.MarshalYAML() - if err != nil { - fail(err) - } - if v == nil { - e.nilv() - return - } - in = reflect.ValueOf(v) - case encoding.TextMarshaler: - text, err := m.MarshalText() - if err != nil { - fail(err) - } - in = reflect.ValueOf(string(text)) - case nil: - e.nilv() - return - } - switch in.Kind() { - case reflect.Interface: - e.marshal(tag, in.Elem()) - case reflect.Map: - e.mapv(tag, in) - case reflect.Ptr: - if in.Type() == ptrTimeType { - e.timev(tag, in.Elem()) - } else { - e.marshal(tag, in.Elem()) - } - case reflect.Struct: - if in.Type() == timeType { - e.timev(tag, in) - } else { - e.structv(tag, in) - } - case reflect.Slice, reflect.Array: - if in.Type().Elem() == mapItemType { - e.itemsv(tag, in) - } else { - e.slicev(tag, in) - } - case reflect.String: - e.stringv(tag, in) - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - if in.Type() == durationType { - e.stringv(tag, reflect.ValueOf(iface.(time.Duration).String())) - } else { - e.intv(tag, in) - } - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: - e.uintv(tag, in) - case reflect.Float32, reflect.Float64: - e.floatv(tag, in) - case reflect.Bool: - e.boolv(tag, in) - default: - panic("cannot marshal type: " + in.Type().String()) - } -} - -func (e *encoder) mapv(tag string, in reflect.Value) { - e.mappingv(tag, func() { - keys := keyList(in.MapKeys()) - sort.Sort(keys) - for _, k := range keys { - e.marshal("", k) - e.marshal("", in.MapIndex(k)) - } - }) -} - -func (e *encoder) itemsv(tag string, in reflect.Value) { - e.mappingv(tag, func() { - slice := in.Convert(reflect.TypeOf([]MapItem{})).Interface().([]MapItem) - for _, item := range slice { - e.marshal("", reflect.ValueOf(item.Key)) - e.marshal("", reflect.ValueOf(item.Value)) - } - }) -} - -func (e *encoder) structv(tag string, in reflect.Value) { - sinfo, err := getStructInfo(in.Type()) - if err != nil { - panic(err) - } - e.mappingv(tag, func() { - for _, info := range sinfo.FieldsList { - var value reflect.Value - if info.Inline == nil { - value = in.Field(info.Num) - } else { - value = in.FieldByIndex(info.Inline) - } - if info.OmitEmpty && isZero(value) { - continue - } - e.marshal("", reflect.ValueOf(info.Key)) - e.flow = info.Flow - e.marshal("", value) - } - if sinfo.InlineMap >= 0 { - m := in.Field(sinfo.InlineMap) - if m.Len() > 0 { - e.flow = false - keys := keyList(m.MapKeys()) - sort.Sort(keys) - for _, k := range keys { - if _, found := sinfo.FieldsMap[k.String()]; found { - panic(fmt.Sprintf("Can't have key %q in inlined map; conflicts with struct field", k.String())) - } - e.marshal("", k) - e.flow = false - e.marshal("", m.MapIndex(k)) - } - } - } - }) -} - -func (e *encoder) mappingv(tag string, f func()) { - implicit := tag == "" - style := yaml_BLOCK_MAPPING_STYLE - if e.flow { - e.flow = false - style = yaml_FLOW_MAPPING_STYLE - } - yaml_mapping_start_event_initialize(&e.event, nil, []byte(tag), implicit, style) - e.emit() - f() - yaml_mapping_end_event_initialize(&e.event) - e.emit() -} - -func (e *encoder) slicev(tag string, in reflect.Value) { - implicit := tag == "" - style := yaml_BLOCK_SEQUENCE_STYLE - if e.flow { - e.flow = false - style = yaml_FLOW_SEQUENCE_STYLE - } - e.must(yaml_sequence_start_event_initialize(&e.event, nil, []byte(tag), implicit, style)) - e.emit() - n := in.Len() - for i := 0; i < n; i++ { - e.marshal("", in.Index(i)) - } - e.must(yaml_sequence_end_event_initialize(&e.event)) - e.emit() -} - -// isBase60 returns whether s is in base 60 notation as defined in YAML 1.1. -// -// The base 60 float notation in YAML 1.1 is a terrible idea and is unsupported -// in YAML 1.2 and by this package, but these should be marshalled quoted for -// the time being for compatibility with other parsers. -func isBase60Float(s string) (result bool) { - // Fast path. - if s == "" { - return false - } - c := s[0] - if !(c == '+' || c == '-' || c >= '0' && c <= '9') || strings.IndexByte(s, ':') < 0 { - return false - } - // Do the full match. - return base60float.MatchString(s) -} - -// From http://yaml.org/type/float.html, except the regular expression there -// is bogus. In practice parsers do not enforce the "\.[0-9_]*" suffix. -var base60float = regexp.MustCompile(`^[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+(?:\.[0-9_]*)?$`) - -func (e *encoder) stringv(tag string, in reflect.Value) { - var style yaml_scalar_style_t - s := in.String() - canUsePlain := true - switch { - case !utf8.ValidString(s): - if tag == yaml_BINARY_TAG { - failf("explicitly tagged !!binary data must be base64-encoded") - } - if tag != "" { - failf("cannot marshal invalid UTF-8 data as %s", shortTag(tag)) - } - // It can't be encoded directly as YAML so use a binary tag - // and encode it as base64. - tag = yaml_BINARY_TAG - s = encodeBase64(s) - case tag == "": - // Check to see if it would resolve to a specific - // tag when encoded unquoted. If it doesn't, - // there's no need to quote it. - rtag, _ := resolve("", s) - canUsePlain = rtag == yaml_STR_TAG && !isBase60Float(s) - } - // Note: it's possible for user code to emit invalid YAML - // if they explicitly specify a tag and a string containing - // text that's incompatible with that tag. - switch { - case strings.Contains(s, "\n"): - style = yaml_LITERAL_SCALAR_STYLE - case canUsePlain: - style = yaml_PLAIN_SCALAR_STYLE - default: - style = yaml_DOUBLE_QUOTED_SCALAR_STYLE - } - e.emitScalar(s, "", tag, style) -} - -func (e *encoder) boolv(tag string, in reflect.Value) { - var s string - if in.Bool() { - s = "true" - } else { - s = "false" - } - e.emitScalar(s, "", tag, yaml_PLAIN_SCALAR_STYLE) -} - -func (e *encoder) intv(tag string, in reflect.Value) { - s := strconv.FormatInt(in.Int(), 10) - e.emitScalar(s, "", tag, yaml_PLAIN_SCALAR_STYLE) -} - -func (e *encoder) uintv(tag string, in reflect.Value) { - s := strconv.FormatUint(in.Uint(), 10) - e.emitScalar(s, "", tag, yaml_PLAIN_SCALAR_STYLE) -} - -func (e *encoder) timev(tag string, in reflect.Value) { - t := in.Interface().(time.Time) - s := t.Format(time.RFC3339Nano) - e.emitScalar(s, "", tag, yaml_PLAIN_SCALAR_STYLE) -} - -func (e *encoder) floatv(tag string, in reflect.Value) { - // Issue #352: When formatting, use the precision of the underlying value - precision := 64 - if in.Kind() == reflect.Float32 { - precision = 32 - } - - s := strconv.FormatFloat(in.Float(), 'g', -1, precision) - switch s { - case "+Inf": - s = ".inf" - case "-Inf": - s = "-.inf" - case "NaN": - s = ".nan" - } - e.emitScalar(s, "", tag, yaml_PLAIN_SCALAR_STYLE) -} - -func (e *encoder) nilv() { - e.emitScalar("null", "", "", yaml_PLAIN_SCALAR_STYLE) -} - -func (e *encoder) emitScalar(value, anchor, tag string, style yaml_scalar_style_t) { - implicit := tag == "" - e.must(yaml_scalar_event_initialize(&e.event, []byte(anchor), []byte(tag), []byte(value), implicit, implicit, style)) - e.emit() -} diff --git a/vendor/gopkg.in/yaml.v2/go.mod b/vendor/gopkg.in/yaml.v2/go.mod deleted file mode 100644 index 1934e87694..0000000000 --- a/vendor/gopkg.in/yaml.v2/go.mod +++ /dev/null @@ -1,5 +0,0 @@ -module "gopkg.in/yaml.v2" - -require ( - "gopkg.in/check.v1" v0.0.0-20161208181325-20d25e280405 -) diff --git a/vendor/gopkg.in/yaml.v2/parserc.go b/vendor/gopkg.in/yaml.v2/parserc.go deleted file mode 100644 index 81d05dfe57..0000000000 --- a/vendor/gopkg.in/yaml.v2/parserc.go +++ /dev/null @@ -1,1095 +0,0 @@ -package yaml - -import ( - "bytes" -) - -// The parser implements the following grammar: -// -// stream ::= STREAM-START implicit_document? explicit_document* STREAM-END -// implicit_document ::= block_node DOCUMENT-END* -// explicit_document ::= DIRECTIVE* DOCUMENT-START block_node? DOCUMENT-END* -// block_node_or_indentless_sequence ::= -// ALIAS -// | properties (block_content | indentless_block_sequence)? -// | block_content -// | indentless_block_sequence -// block_node ::= ALIAS -// | properties block_content? -// | block_content -// flow_node ::= ALIAS -// | properties flow_content? -// | flow_content -// properties ::= TAG ANCHOR? | ANCHOR TAG? -// block_content ::= block_collection | flow_collection | SCALAR -// flow_content ::= flow_collection | SCALAR -// block_collection ::= block_sequence | block_mapping -// flow_collection ::= flow_sequence | flow_mapping -// block_sequence ::= BLOCK-SEQUENCE-START (BLOCK-ENTRY block_node?)* BLOCK-END -// indentless_sequence ::= (BLOCK-ENTRY block_node?)+ -// block_mapping ::= BLOCK-MAPPING_START -// ((KEY block_node_or_indentless_sequence?)? -// (VALUE block_node_or_indentless_sequence?)?)* -// BLOCK-END -// flow_sequence ::= FLOW-SEQUENCE-START -// (flow_sequence_entry FLOW-ENTRY)* -// flow_sequence_entry? -// FLOW-SEQUENCE-END -// flow_sequence_entry ::= flow_node | KEY flow_node? (VALUE flow_node?)? -// flow_mapping ::= FLOW-MAPPING-START -// (flow_mapping_entry FLOW-ENTRY)* -// flow_mapping_entry? -// FLOW-MAPPING-END -// flow_mapping_entry ::= flow_node | KEY flow_node? (VALUE flow_node?)? - -// Peek the next token in the token queue. -func peek_token(parser *yaml_parser_t) *yaml_token_t { - if parser.token_available || yaml_parser_fetch_more_tokens(parser) { - return &parser.tokens[parser.tokens_head] - } - return nil -} - -// Remove the next token from the queue (must be called after peek_token). -func skip_token(parser *yaml_parser_t) { - parser.token_available = false - parser.tokens_parsed++ - parser.stream_end_produced = parser.tokens[parser.tokens_head].typ == yaml_STREAM_END_TOKEN - parser.tokens_head++ -} - -// Get the next event. -func yaml_parser_parse(parser *yaml_parser_t, event *yaml_event_t) bool { - // Erase the event object. - *event = yaml_event_t{} - - // No events after the end of the stream or error. - if parser.stream_end_produced || parser.error != yaml_NO_ERROR || parser.state == yaml_PARSE_END_STATE { - return true - } - - // Generate the next event. - return yaml_parser_state_machine(parser, event) -} - -// Set parser error. -func yaml_parser_set_parser_error(parser *yaml_parser_t, problem string, problem_mark yaml_mark_t) bool { - parser.error = yaml_PARSER_ERROR - parser.problem = problem - parser.problem_mark = problem_mark - return false -} - -func yaml_parser_set_parser_error_context(parser *yaml_parser_t, context string, context_mark yaml_mark_t, problem string, problem_mark yaml_mark_t) bool { - parser.error = yaml_PARSER_ERROR - parser.context = context - parser.context_mark = context_mark - parser.problem = problem - parser.problem_mark = problem_mark - return false -} - -// State dispatcher. -func yaml_parser_state_machine(parser *yaml_parser_t, event *yaml_event_t) bool { - //trace("yaml_parser_state_machine", "state:", parser.state.String()) - - switch parser.state { - case yaml_PARSE_STREAM_START_STATE: - return yaml_parser_parse_stream_start(parser, event) - - case yaml_PARSE_IMPLICIT_DOCUMENT_START_STATE: - return yaml_parser_parse_document_start(parser, event, true) - - case yaml_PARSE_DOCUMENT_START_STATE: - return yaml_parser_parse_document_start(parser, event, false) - - case yaml_PARSE_DOCUMENT_CONTENT_STATE: - return yaml_parser_parse_document_content(parser, event) - - case yaml_PARSE_DOCUMENT_END_STATE: - return yaml_parser_parse_document_end(parser, event) - - case yaml_PARSE_BLOCK_NODE_STATE: - return yaml_parser_parse_node(parser, event, true, false) - - case yaml_PARSE_BLOCK_NODE_OR_INDENTLESS_SEQUENCE_STATE: - return yaml_parser_parse_node(parser, event, true, true) - - case yaml_PARSE_FLOW_NODE_STATE: - return yaml_parser_parse_node(parser, event, false, false) - - case yaml_PARSE_BLOCK_SEQUENCE_FIRST_ENTRY_STATE: - return yaml_parser_parse_block_sequence_entry(parser, event, true) - - case yaml_PARSE_BLOCK_SEQUENCE_ENTRY_STATE: - return yaml_parser_parse_block_sequence_entry(parser, event, false) - - case yaml_PARSE_INDENTLESS_SEQUENCE_ENTRY_STATE: - return yaml_parser_parse_indentless_sequence_entry(parser, event) - - case yaml_PARSE_BLOCK_MAPPING_FIRST_KEY_STATE: - return yaml_parser_parse_block_mapping_key(parser, event, true) - - case yaml_PARSE_BLOCK_MAPPING_KEY_STATE: - return yaml_parser_parse_block_mapping_key(parser, event, false) - - case yaml_PARSE_BLOCK_MAPPING_VALUE_STATE: - return yaml_parser_parse_block_mapping_value(parser, event) - - case yaml_PARSE_FLOW_SEQUENCE_FIRST_ENTRY_STATE: - return yaml_parser_parse_flow_sequence_entry(parser, event, true) - - case yaml_PARSE_FLOW_SEQUENCE_ENTRY_STATE: - return yaml_parser_parse_flow_sequence_entry(parser, event, false) - - case yaml_PARSE_FLOW_SEQUENCE_ENTRY_MAPPING_KEY_STATE: - return yaml_parser_parse_flow_sequence_entry_mapping_key(parser, event) - - case yaml_PARSE_FLOW_SEQUENCE_ENTRY_MAPPING_VALUE_STATE: - return yaml_parser_parse_flow_sequence_entry_mapping_value(parser, event) - - case yaml_PARSE_FLOW_SEQUENCE_ENTRY_MAPPING_END_STATE: - return yaml_parser_parse_flow_sequence_entry_mapping_end(parser, event) - - case yaml_PARSE_FLOW_MAPPING_FIRST_KEY_STATE: - return yaml_parser_parse_flow_mapping_key(parser, event, true) - - case yaml_PARSE_FLOW_MAPPING_KEY_STATE: - return yaml_parser_parse_flow_mapping_key(parser, event, false) - - case yaml_PARSE_FLOW_MAPPING_VALUE_STATE: - return yaml_parser_parse_flow_mapping_value(parser, event, false) - - case yaml_PARSE_FLOW_MAPPING_EMPTY_VALUE_STATE: - return yaml_parser_parse_flow_mapping_value(parser, event, true) - - default: - panic("invalid parser state") - } -} - -// Parse the production: -// stream ::= STREAM-START implicit_document? explicit_document* STREAM-END -// ************ -func yaml_parser_parse_stream_start(parser *yaml_parser_t, event *yaml_event_t) bool { - token := peek_token(parser) - if token == nil { - return false - } - if token.typ != yaml_STREAM_START_TOKEN { - return yaml_parser_set_parser_error(parser, "did not find expected ", token.start_mark) - } - parser.state = yaml_PARSE_IMPLICIT_DOCUMENT_START_STATE - *event = yaml_event_t{ - typ: yaml_STREAM_START_EVENT, - start_mark: token.start_mark, - end_mark: token.end_mark, - encoding: token.encoding, - } - skip_token(parser) - return true -} - -// Parse the productions: -// implicit_document ::= block_node DOCUMENT-END* -// * -// explicit_document ::= DIRECTIVE* DOCUMENT-START block_node? DOCUMENT-END* -// ************************* -func yaml_parser_parse_document_start(parser *yaml_parser_t, event *yaml_event_t, implicit bool) bool { - - token := peek_token(parser) - if token == nil { - return false - } - - // Parse extra document end indicators. - if !implicit { - for token.typ == yaml_DOCUMENT_END_TOKEN { - skip_token(parser) - token = peek_token(parser) - if token == nil { - return false - } - } - } - - if implicit && token.typ != yaml_VERSION_DIRECTIVE_TOKEN && - token.typ != yaml_TAG_DIRECTIVE_TOKEN && - token.typ != yaml_DOCUMENT_START_TOKEN && - token.typ != yaml_STREAM_END_TOKEN { - // Parse an implicit document. - if !yaml_parser_process_directives(parser, nil, nil) { - return false - } - parser.states = append(parser.states, yaml_PARSE_DOCUMENT_END_STATE) - parser.state = yaml_PARSE_BLOCK_NODE_STATE - - *event = yaml_event_t{ - typ: yaml_DOCUMENT_START_EVENT, - start_mark: token.start_mark, - end_mark: token.end_mark, - } - - } else if token.typ != yaml_STREAM_END_TOKEN { - // Parse an explicit document. - var version_directive *yaml_version_directive_t - var tag_directives []yaml_tag_directive_t - start_mark := token.start_mark - if !yaml_parser_process_directives(parser, &version_directive, &tag_directives) { - return false - } - token = peek_token(parser) - if token == nil { - return false - } - if token.typ != yaml_DOCUMENT_START_TOKEN { - yaml_parser_set_parser_error(parser, - "did not find expected ", token.start_mark) - return false - } - parser.states = append(parser.states, yaml_PARSE_DOCUMENT_END_STATE) - parser.state = yaml_PARSE_DOCUMENT_CONTENT_STATE - end_mark := token.end_mark - - *event = yaml_event_t{ - typ: yaml_DOCUMENT_START_EVENT, - start_mark: start_mark, - end_mark: end_mark, - version_directive: version_directive, - tag_directives: tag_directives, - implicit: false, - } - skip_token(parser) - - } else { - // Parse the stream end. - parser.state = yaml_PARSE_END_STATE - *event = yaml_event_t{ - typ: yaml_STREAM_END_EVENT, - start_mark: token.start_mark, - end_mark: token.end_mark, - } - skip_token(parser) - } - - return true -} - -// Parse the productions: -// explicit_document ::= DIRECTIVE* DOCUMENT-START block_node? DOCUMENT-END* -// *********** -// -func yaml_parser_parse_document_content(parser *yaml_parser_t, event *yaml_event_t) bool { - token := peek_token(parser) - if token == nil { - return false - } - if token.typ == yaml_VERSION_DIRECTIVE_TOKEN || - token.typ == yaml_TAG_DIRECTIVE_TOKEN || - token.typ == yaml_DOCUMENT_START_TOKEN || - token.typ == yaml_DOCUMENT_END_TOKEN || - token.typ == yaml_STREAM_END_TOKEN { - parser.state = parser.states[len(parser.states)-1] - parser.states = parser.states[:len(parser.states)-1] - return yaml_parser_process_empty_scalar(parser, event, - token.start_mark) - } - return yaml_parser_parse_node(parser, event, true, false) -} - -// Parse the productions: -// implicit_document ::= block_node DOCUMENT-END* -// ************* -// explicit_document ::= DIRECTIVE* DOCUMENT-START block_node? DOCUMENT-END* -// -func yaml_parser_parse_document_end(parser *yaml_parser_t, event *yaml_event_t) bool { - token := peek_token(parser) - if token == nil { - return false - } - - start_mark := token.start_mark - end_mark := token.start_mark - - implicit := true - if token.typ == yaml_DOCUMENT_END_TOKEN { - end_mark = token.end_mark - skip_token(parser) - implicit = false - } - - parser.tag_directives = parser.tag_directives[:0] - - parser.state = yaml_PARSE_DOCUMENT_START_STATE - *event = yaml_event_t{ - typ: yaml_DOCUMENT_END_EVENT, - start_mark: start_mark, - end_mark: end_mark, - implicit: implicit, - } - return true -} - -// Parse the productions: -// block_node_or_indentless_sequence ::= -// ALIAS -// ***** -// | properties (block_content | indentless_block_sequence)? -// ********** * -// | block_content | indentless_block_sequence -// * -// block_node ::= ALIAS -// ***** -// | properties block_content? -// ********** * -// | block_content -// * -// flow_node ::= ALIAS -// ***** -// | properties flow_content? -// ********** * -// | flow_content -// * -// properties ::= TAG ANCHOR? | ANCHOR TAG? -// ************************* -// block_content ::= block_collection | flow_collection | SCALAR -// ****** -// flow_content ::= flow_collection | SCALAR -// ****** -func yaml_parser_parse_node(parser *yaml_parser_t, event *yaml_event_t, block, indentless_sequence bool) bool { - //defer trace("yaml_parser_parse_node", "block:", block, "indentless_sequence:", indentless_sequence)() - - token := peek_token(parser) - if token == nil { - return false - } - - if token.typ == yaml_ALIAS_TOKEN { - parser.state = parser.states[len(parser.states)-1] - parser.states = parser.states[:len(parser.states)-1] - *event = yaml_event_t{ - typ: yaml_ALIAS_EVENT, - start_mark: token.start_mark, - end_mark: token.end_mark, - anchor: token.value, - } - skip_token(parser) - return true - } - - start_mark := token.start_mark - end_mark := token.start_mark - - var tag_token bool - var tag_handle, tag_suffix, anchor []byte - var tag_mark yaml_mark_t - if token.typ == yaml_ANCHOR_TOKEN { - anchor = token.value - start_mark = token.start_mark - end_mark = token.end_mark - skip_token(parser) - token = peek_token(parser) - if token == nil { - return false - } - if token.typ == yaml_TAG_TOKEN { - tag_token = true - tag_handle = token.value - tag_suffix = token.suffix - tag_mark = token.start_mark - end_mark = token.end_mark - skip_token(parser) - token = peek_token(parser) - if token == nil { - return false - } - } - } else if token.typ == yaml_TAG_TOKEN { - tag_token = true - tag_handle = token.value - tag_suffix = token.suffix - start_mark = token.start_mark - tag_mark = token.start_mark - end_mark = token.end_mark - skip_token(parser) - token = peek_token(parser) - if token == nil { - return false - } - if token.typ == yaml_ANCHOR_TOKEN { - anchor = token.value - end_mark = token.end_mark - skip_token(parser) - token = peek_token(parser) - if token == nil { - return false - } - } - } - - var tag []byte - if tag_token { - if len(tag_handle) == 0 { - tag = tag_suffix - tag_suffix = nil - } else { - for i := range parser.tag_directives { - if bytes.Equal(parser.tag_directives[i].handle, tag_handle) { - tag = append([]byte(nil), parser.tag_directives[i].prefix...) - tag = append(tag, tag_suffix...) - break - } - } - if len(tag) == 0 { - yaml_parser_set_parser_error_context(parser, - "while parsing a node", start_mark, - "found undefined tag handle", tag_mark) - return false - } - } - } - - implicit := len(tag) == 0 - if indentless_sequence && token.typ == yaml_BLOCK_ENTRY_TOKEN { - end_mark = token.end_mark - parser.state = yaml_PARSE_INDENTLESS_SEQUENCE_ENTRY_STATE - *event = yaml_event_t{ - typ: yaml_SEQUENCE_START_EVENT, - start_mark: start_mark, - end_mark: end_mark, - anchor: anchor, - tag: tag, - implicit: implicit, - style: yaml_style_t(yaml_BLOCK_SEQUENCE_STYLE), - } - return true - } - if token.typ == yaml_SCALAR_TOKEN { - var plain_implicit, quoted_implicit bool - end_mark = token.end_mark - if (len(tag) == 0 && token.style == yaml_PLAIN_SCALAR_STYLE) || (len(tag) == 1 && tag[0] == '!') { - plain_implicit = true - } else if len(tag) == 0 { - quoted_implicit = true - } - parser.state = parser.states[len(parser.states)-1] - parser.states = parser.states[:len(parser.states)-1] - - *event = yaml_event_t{ - typ: yaml_SCALAR_EVENT, - start_mark: start_mark, - end_mark: end_mark, - anchor: anchor, - tag: tag, - value: token.value, - implicit: plain_implicit, - quoted_implicit: quoted_implicit, - style: yaml_style_t(token.style), - } - skip_token(parser) - return true - } - if token.typ == yaml_FLOW_SEQUENCE_START_TOKEN { - // [Go] Some of the events below can be merged as they differ only on style. - end_mark = token.end_mark - parser.state = yaml_PARSE_FLOW_SEQUENCE_FIRST_ENTRY_STATE - *event = yaml_event_t{ - typ: yaml_SEQUENCE_START_EVENT, - start_mark: start_mark, - end_mark: end_mark, - anchor: anchor, - tag: tag, - implicit: implicit, - style: yaml_style_t(yaml_FLOW_SEQUENCE_STYLE), - } - return true - } - if token.typ == yaml_FLOW_MAPPING_START_TOKEN { - end_mark = token.end_mark - parser.state = yaml_PARSE_FLOW_MAPPING_FIRST_KEY_STATE - *event = yaml_event_t{ - typ: yaml_MAPPING_START_EVENT, - start_mark: start_mark, - end_mark: end_mark, - anchor: anchor, - tag: tag, - implicit: implicit, - style: yaml_style_t(yaml_FLOW_MAPPING_STYLE), - } - return true - } - if block && token.typ == yaml_BLOCK_SEQUENCE_START_TOKEN { - end_mark = token.end_mark - parser.state = yaml_PARSE_BLOCK_SEQUENCE_FIRST_ENTRY_STATE - *event = yaml_event_t{ - typ: yaml_SEQUENCE_START_EVENT, - start_mark: start_mark, - end_mark: end_mark, - anchor: anchor, - tag: tag, - implicit: implicit, - style: yaml_style_t(yaml_BLOCK_SEQUENCE_STYLE), - } - return true - } - if block && token.typ == yaml_BLOCK_MAPPING_START_TOKEN { - end_mark = token.end_mark - parser.state = yaml_PARSE_BLOCK_MAPPING_FIRST_KEY_STATE - *event = yaml_event_t{ - typ: yaml_MAPPING_START_EVENT, - start_mark: start_mark, - end_mark: end_mark, - anchor: anchor, - tag: tag, - implicit: implicit, - style: yaml_style_t(yaml_BLOCK_MAPPING_STYLE), - } - return true - } - if len(anchor) > 0 || len(tag) > 0 { - parser.state = parser.states[len(parser.states)-1] - parser.states = parser.states[:len(parser.states)-1] - - *event = yaml_event_t{ - typ: yaml_SCALAR_EVENT, - start_mark: start_mark, - end_mark: end_mark, - anchor: anchor, - tag: tag, - implicit: implicit, - quoted_implicit: false, - style: yaml_style_t(yaml_PLAIN_SCALAR_STYLE), - } - return true - } - - context := "while parsing a flow node" - if block { - context = "while parsing a block node" - } - yaml_parser_set_parser_error_context(parser, context, start_mark, - "did not find expected node content", token.start_mark) - return false -} - -// Parse the productions: -// block_sequence ::= BLOCK-SEQUENCE-START (BLOCK-ENTRY block_node?)* BLOCK-END -// ******************** *********** * ********* -// -func yaml_parser_parse_block_sequence_entry(parser *yaml_parser_t, event *yaml_event_t, first bool) bool { - if first { - token := peek_token(parser) - parser.marks = append(parser.marks, token.start_mark) - skip_token(parser) - } - - token := peek_token(parser) - if token == nil { - return false - } - - if token.typ == yaml_BLOCK_ENTRY_TOKEN { - mark := token.end_mark - skip_token(parser) - token = peek_token(parser) - if token == nil { - return false - } - if token.typ != yaml_BLOCK_ENTRY_TOKEN && token.typ != yaml_BLOCK_END_TOKEN { - parser.states = append(parser.states, yaml_PARSE_BLOCK_SEQUENCE_ENTRY_STATE) - return yaml_parser_parse_node(parser, event, true, false) - } else { - parser.state = yaml_PARSE_BLOCK_SEQUENCE_ENTRY_STATE - return yaml_parser_process_empty_scalar(parser, event, mark) - } - } - if token.typ == yaml_BLOCK_END_TOKEN { - parser.state = parser.states[len(parser.states)-1] - parser.states = parser.states[:len(parser.states)-1] - parser.marks = parser.marks[:len(parser.marks)-1] - - *event = yaml_event_t{ - typ: yaml_SEQUENCE_END_EVENT, - start_mark: token.start_mark, - end_mark: token.end_mark, - } - - skip_token(parser) - return true - } - - context_mark := parser.marks[len(parser.marks)-1] - parser.marks = parser.marks[:len(parser.marks)-1] - return yaml_parser_set_parser_error_context(parser, - "while parsing a block collection", context_mark, - "did not find expected '-' indicator", token.start_mark) -} - -// Parse the productions: -// indentless_sequence ::= (BLOCK-ENTRY block_node?)+ -// *********** * -func yaml_parser_parse_indentless_sequence_entry(parser *yaml_parser_t, event *yaml_event_t) bool { - token := peek_token(parser) - if token == nil { - return false - } - - if token.typ == yaml_BLOCK_ENTRY_TOKEN { - mark := token.end_mark - skip_token(parser) - token = peek_token(parser) - if token == nil { - return false - } - if token.typ != yaml_BLOCK_ENTRY_TOKEN && - token.typ != yaml_KEY_TOKEN && - token.typ != yaml_VALUE_TOKEN && - token.typ != yaml_BLOCK_END_TOKEN { - parser.states = append(parser.states, yaml_PARSE_INDENTLESS_SEQUENCE_ENTRY_STATE) - return yaml_parser_parse_node(parser, event, true, false) - } - parser.state = yaml_PARSE_INDENTLESS_SEQUENCE_ENTRY_STATE - return yaml_parser_process_empty_scalar(parser, event, mark) - } - parser.state = parser.states[len(parser.states)-1] - parser.states = parser.states[:len(parser.states)-1] - - *event = yaml_event_t{ - typ: yaml_SEQUENCE_END_EVENT, - start_mark: token.start_mark, - end_mark: token.start_mark, // [Go] Shouldn't this be token.end_mark? - } - return true -} - -// Parse the productions: -// block_mapping ::= BLOCK-MAPPING_START -// ******************* -// ((KEY block_node_or_indentless_sequence?)? -// *** * -// (VALUE block_node_or_indentless_sequence?)?)* -// -// BLOCK-END -// ********* -// -func yaml_parser_parse_block_mapping_key(parser *yaml_parser_t, event *yaml_event_t, first bool) bool { - if first { - token := peek_token(parser) - parser.marks = append(parser.marks, token.start_mark) - skip_token(parser) - } - - token := peek_token(parser) - if token == nil { - return false - } - - if token.typ == yaml_KEY_TOKEN { - mark := token.end_mark - skip_token(parser) - token = peek_token(parser) - if token == nil { - return false - } - if token.typ != yaml_KEY_TOKEN && - token.typ != yaml_VALUE_TOKEN && - token.typ != yaml_BLOCK_END_TOKEN { - parser.states = append(parser.states, yaml_PARSE_BLOCK_MAPPING_VALUE_STATE) - return yaml_parser_parse_node(parser, event, true, true) - } else { - parser.state = yaml_PARSE_BLOCK_MAPPING_VALUE_STATE - return yaml_parser_process_empty_scalar(parser, event, mark) - } - } else if token.typ == yaml_BLOCK_END_TOKEN { - parser.state = parser.states[len(parser.states)-1] - parser.states = parser.states[:len(parser.states)-1] - parser.marks = parser.marks[:len(parser.marks)-1] - *event = yaml_event_t{ - typ: yaml_MAPPING_END_EVENT, - start_mark: token.start_mark, - end_mark: token.end_mark, - } - skip_token(parser) - return true - } - - context_mark := parser.marks[len(parser.marks)-1] - parser.marks = parser.marks[:len(parser.marks)-1] - return yaml_parser_set_parser_error_context(parser, - "while parsing a block mapping", context_mark, - "did not find expected key", token.start_mark) -} - -// Parse the productions: -// block_mapping ::= BLOCK-MAPPING_START -// -// ((KEY block_node_or_indentless_sequence?)? -// -// (VALUE block_node_or_indentless_sequence?)?)* -// ***** * -// BLOCK-END -// -// -func yaml_parser_parse_block_mapping_value(parser *yaml_parser_t, event *yaml_event_t) bool { - token := peek_token(parser) - if token == nil { - return false - } - if token.typ == yaml_VALUE_TOKEN { - mark := token.end_mark - skip_token(parser) - token = peek_token(parser) - if token == nil { - return false - } - if token.typ != yaml_KEY_TOKEN && - token.typ != yaml_VALUE_TOKEN && - token.typ != yaml_BLOCK_END_TOKEN { - parser.states = append(parser.states, yaml_PARSE_BLOCK_MAPPING_KEY_STATE) - return yaml_parser_parse_node(parser, event, true, true) - } - parser.state = yaml_PARSE_BLOCK_MAPPING_KEY_STATE - return yaml_parser_process_empty_scalar(parser, event, mark) - } - parser.state = yaml_PARSE_BLOCK_MAPPING_KEY_STATE - return yaml_parser_process_empty_scalar(parser, event, token.start_mark) -} - -// Parse the productions: -// flow_sequence ::= FLOW-SEQUENCE-START -// ******************* -// (flow_sequence_entry FLOW-ENTRY)* -// * ********** -// flow_sequence_entry? -// * -// FLOW-SEQUENCE-END -// ***************** -// flow_sequence_entry ::= flow_node | KEY flow_node? (VALUE flow_node?)? -// * -// -func yaml_parser_parse_flow_sequence_entry(parser *yaml_parser_t, event *yaml_event_t, first bool) bool { - if first { - token := peek_token(parser) - parser.marks = append(parser.marks, token.start_mark) - skip_token(parser) - } - token := peek_token(parser) - if token == nil { - return false - } - if token.typ != yaml_FLOW_SEQUENCE_END_TOKEN { - if !first { - if token.typ == yaml_FLOW_ENTRY_TOKEN { - skip_token(parser) - token = peek_token(parser) - if token == nil { - return false - } - } else { - context_mark := parser.marks[len(parser.marks)-1] - parser.marks = parser.marks[:len(parser.marks)-1] - return yaml_parser_set_parser_error_context(parser, - "while parsing a flow sequence", context_mark, - "did not find expected ',' or ']'", token.start_mark) - } - } - - if token.typ == yaml_KEY_TOKEN { - parser.state = yaml_PARSE_FLOW_SEQUENCE_ENTRY_MAPPING_KEY_STATE - *event = yaml_event_t{ - typ: yaml_MAPPING_START_EVENT, - start_mark: token.start_mark, - end_mark: token.end_mark, - implicit: true, - style: yaml_style_t(yaml_FLOW_MAPPING_STYLE), - } - skip_token(parser) - return true - } else if token.typ != yaml_FLOW_SEQUENCE_END_TOKEN { - parser.states = append(parser.states, yaml_PARSE_FLOW_SEQUENCE_ENTRY_STATE) - return yaml_parser_parse_node(parser, event, false, false) - } - } - - parser.state = parser.states[len(parser.states)-1] - parser.states = parser.states[:len(parser.states)-1] - parser.marks = parser.marks[:len(parser.marks)-1] - - *event = yaml_event_t{ - typ: yaml_SEQUENCE_END_EVENT, - start_mark: token.start_mark, - end_mark: token.end_mark, - } - - skip_token(parser) - return true -} - -// -// Parse the productions: -// flow_sequence_entry ::= flow_node | KEY flow_node? (VALUE flow_node?)? -// *** * -// -func yaml_parser_parse_flow_sequence_entry_mapping_key(parser *yaml_parser_t, event *yaml_event_t) bool { - token := peek_token(parser) - if token == nil { - return false - } - if token.typ != yaml_VALUE_TOKEN && - token.typ != yaml_FLOW_ENTRY_TOKEN && - token.typ != yaml_FLOW_SEQUENCE_END_TOKEN { - parser.states = append(parser.states, yaml_PARSE_FLOW_SEQUENCE_ENTRY_MAPPING_VALUE_STATE) - return yaml_parser_parse_node(parser, event, false, false) - } - mark := token.end_mark - skip_token(parser) - parser.state = yaml_PARSE_FLOW_SEQUENCE_ENTRY_MAPPING_VALUE_STATE - return yaml_parser_process_empty_scalar(parser, event, mark) -} - -// Parse the productions: -// flow_sequence_entry ::= flow_node | KEY flow_node? (VALUE flow_node?)? -// ***** * -// -func yaml_parser_parse_flow_sequence_entry_mapping_value(parser *yaml_parser_t, event *yaml_event_t) bool { - token := peek_token(parser) - if token == nil { - return false - } - if token.typ == yaml_VALUE_TOKEN { - skip_token(parser) - token := peek_token(parser) - if token == nil { - return false - } - if token.typ != yaml_FLOW_ENTRY_TOKEN && token.typ != yaml_FLOW_SEQUENCE_END_TOKEN { - parser.states = append(parser.states, yaml_PARSE_FLOW_SEQUENCE_ENTRY_MAPPING_END_STATE) - return yaml_parser_parse_node(parser, event, false, false) - } - } - parser.state = yaml_PARSE_FLOW_SEQUENCE_ENTRY_MAPPING_END_STATE - return yaml_parser_process_empty_scalar(parser, event, token.start_mark) -} - -// Parse the productions: -// flow_sequence_entry ::= flow_node | KEY flow_node? (VALUE flow_node?)? -// * -// -func yaml_parser_parse_flow_sequence_entry_mapping_end(parser *yaml_parser_t, event *yaml_event_t) bool { - token := peek_token(parser) - if token == nil { - return false - } - parser.state = yaml_PARSE_FLOW_SEQUENCE_ENTRY_STATE - *event = yaml_event_t{ - typ: yaml_MAPPING_END_EVENT, - start_mark: token.start_mark, - end_mark: token.start_mark, // [Go] Shouldn't this be end_mark? - } - return true -} - -// Parse the productions: -// flow_mapping ::= FLOW-MAPPING-START -// ****************** -// (flow_mapping_entry FLOW-ENTRY)* -// * ********** -// flow_mapping_entry? -// ****************** -// FLOW-MAPPING-END -// **************** -// flow_mapping_entry ::= flow_node | KEY flow_node? (VALUE flow_node?)? -// * *** * -// -func yaml_parser_parse_flow_mapping_key(parser *yaml_parser_t, event *yaml_event_t, first bool) bool { - if first { - token := peek_token(parser) - parser.marks = append(parser.marks, token.start_mark) - skip_token(parser) - } - - token := peek_token(parser) - if token == nil { - return false - } - - if token.typ != yaml_FLOW_MAPPING_END_TOKEN { - if !first { - if token.typ == yaml_FLOW_ENTRY_TOKEN { - skip_token(parser) - token = peek_token(parser) - if token == nil { - return false - } - } else { - context_mark := parser.marks[len(parser.marks)-1] - parser.marks = parser.marks[:len(parser.marks)-1] - return yaml_parser_set_parser_error_context(parser, - "while parsing a flow mapping", context_mark, - "did not find expected ',' or '}'", token.start_mark) - } - } - - if token.typ == yaml_KEY_TOKEN { - skip_token(parser) - token = peek_token(parser) - if token == nil { - return false - } - if token.typ != yaml_VALUE_TOKEN && - token.typ != yaml_FLOW_ENTRY_TOKEN && - token.typ != yaml_FLOW_MAPPING_END_TOKEN { - parser.states = append(parser.states, yaml_PARSE_FLOW_MAPPING_VALUE_STATE) - return yaml_parser_parse_node(parser, event, false, false) - } else { - parser.state = yaml_PARSE_FLOW_MAPPING_VALUE_STATE - return yaml_parser_process_empty_scalar(parser, event, token.start_mark) - } - } else if token.typ != yaml_FLOW_MAPPING_END_TOKEN { - parser.states = append(parser.states, yaml_PARSE_FLOW_MAPPING_EMPTY_VALUE_STATE) - return yaml_parser_parse_node(parser, event, false, false) - } - } - - parser.state = parser.states[len(parser.states)-1] - parser.states = parser.states[:len(parser.states)-1] - parser.marks = parser.marks[:len(parser.marks)-1] - *event = yaml_event_t{ - typ: yaml_MAPPING_END_EVENT, - start_mark: token.start_mark, - end_mark: token.end_mark, - } - skip_token(parser) - return true -} - -// Parse the productions: -// flow_mapping_entry ::= flow_node | KEY flow_node? (VALUE flow_node?)? -// * ***** * -// -func yaml_parser_parse_flow_mapping_value(parser *yaml_parser_t, event *yaml_event_t, empty bool) bool { - token := peek_token(parser) - if token == nil { - return false - } - if empty { - parser.state = yaml_PARSE_FLOW_MAPPING_KEY_STATE - return yaml_parser_process_empty_scalar(parser, event, token.start_mark) - } - if token.typ == yaml_VALUE_TOKEN { - skip_token(parser) - token = peek_token(parser) - if token == nil { - return false - } - if token.typ != yaml_FLOW_ENTRY_TOKEN && token.typ != yaml_FLOW_MAPPING_END_TOKEN { - parser.states = append(parser.states, yaml_PARSE_FLOW_MAPPING_KEY_STATE) - return yaml_parser_parse_node(parser, event, false, false) - } - } - parser.state = yaml_PARSE_FLOW_MAPPING_KEY_STATE - return yaml_parser_process_empty_scalar(parser, event, token.start_mark) -} - -// Generate an empty scalar event. -func yaml_parser_process_empty_scalar(parser *yaml_parser_t, event *yaml_event_t, mark yaml_mark_t) bool { - *event = yaml_event_t{ - typ: yaml_SCALAR_EVENT, - start_mark: mark, - end_mark: mark, - value: nil, // Empty - implicit: true, - style: yaml_style_t(yaml_PLAIN_SCALAR_STYLE), - } - return true -} - -var default_tag_directives = []yaml_tag_directive_t{ - {[]byte("!"), []byte("!")}, - {[]byte("!!"), []byte("tag:yaml.org,2002:")}, -} - -// Parse directives. -func yaml_parser_process_directives(parser *yaml_parser_t, - version_directive_ref **yaml_version_directive_t, - tag_directives_ref *[]yaml_tag_directive_t) bool { - - var version_directive *yaml_version_directive_t - var tag_directives []yaml_tag_directive_t - - token := peek_token(parser) - if token == nil { - return false - } - - for token.typ == yaml_VERSION_DIRECTIVE_TOKEN || token.typ == yaml_TAG_DIRECTIVE_TOKEN { - if token.typ == yaml_VERSION_DIRECTIVE_TOKEN { - if version_directive != nil { - yaml_parser_set_parser_error(parser, - "found duplicate %YAML directive", token.start_mark) - return false - } - if token.major != 1 || token.minor != 1 { - yaml_parser_set_parser_error(parser, - "found incompatible YAML document", token.start_mark) - return false - } - version_directive = &yaml_version_directive_t{ - major: token.major, - minor: token.minor, - } - } else if token.typ == yaml_TAG_DIRECTIVE_TOKEN { - value := yaml_tag_directive_t{ - handle: token.value, - prefix: token.prefix, - } - if !yaml_parser_append_tag_directive(parser, value, false, token.start_mark) { - return false - } - tag_directives = append(tag_directives, value) - } - - skip_token(parser) - token = peek_token(parser) - if token == nil { - return false - } - } - - for i := range default_tag_directives { - if !yaml_parser_append_tag_directive(parser, default_tag_directives[i], true, token.start_mark) { - return false - } - } - - if version_directive_ref != nil { - *version_directive_ref = version_directive - } - if tag_directives_ref != nil { - *tag_directives_ref = tag_directives - } - return true -} - -// Append a tag directive to the directives stack. -func yaml_parser_append_tag_directive(parser *yaml_parser_t, value yaml_tag_directive_t, allow_duplicates bool, mark yaml_mark_t) bool { - for i := range parser.tag_directives { - if bytes.Equal(value.handle, parser.tag_directives[i].handle) { - if allow_duplicates { - return true - } - return yaml_parser_set_parser_error(parser, "found duplicate %TAG directive", mark) - } - } - - // [Go] I suspect the copy is unnecessary. This was likely done - // because there was no way to track ownership of the data. - value_copy := yaml_tag_directive_t{ - handle: make([]byte, len(value.handle)), - prefix: make([]byte, len(value.prefix)), - } - copy(value_copy.handle, value.handle) - copy(value_copy.prefix, value.prefix) - parser.tag_directives = append(parser.tag_directives, value_copy) - return true -} diff --git a/vendor/gopkg.in/yaml.v2/readerc.go b/vendor/gopkg.in/yaml.v2/readerc.go deleted file mode 100644 index 7c1f5fac3d..0000000000 --- a/vendor/gopkg.in/yaml.v2/readerc.go +++ /dev/null @@ -1,412 +0,0 @@ -package yaml - -import ( - "io" -) - -// Set the reader error and return 0. -func yaml_parser_set_reader_error(parser *yaml_parser_t, problem string, offset int, value int) bool { - parser.error = yaml_READER_ERROR - parser.problem = problem - parser.problem_offset = offset - parser.problem_value = value - return false -} - -// Byte order marks. -const ( - bom_UTF8 = "\xef\xbb\xbf" - bom_UTF16LE = "\xff\xfe" - bom_UTF16BE = "\xfe\xff" -) - -// Determine the input stream encoding by checking the BOM symbol. If no BOM is -// found, the UTF-8 encoding is assumed. Return 1 on success, 0 on failure. -func yaml_parser_determine_encoding(parser *yaml_parser_t) bool { - // Ensure that we had enough bytes in the raw buffer. - for !parser.eof && len(parser.raw_buffer)-parser.raw_buffer_pos < 3 { - if !yaml_parser_update_raw_buffer(parser) { - return false - } - } - - // Determine the encoding. - buf := parser.raw_buffer - pos := parser.raw_buffer_pos - avail := len(buf) - pos - if avail >= 2 && buf[pos] == bom_UTF16LE[0] && buf[pos+1] == bom_UTF16LE[1] { - parser.encoding = yaml_UTF16LE_ENCODING - parser.raw_buffer_pos += 2 - parser.offset += 2 - } else if avail >= 2 && buf[pos] == bom_UTF16BE[0] && buf[pos+1] == bom_UTF16BE[1] { - parser.encoding = yaml_UTF16BE_ENCODING - parser.raw_buffer_pos += 2 - parser.offset += 2 - } else if avail >= 3 && buf[pos] == bom_UTF8[0] && buf[pos+1] == bom_UTF8[1] && buf[pos+2] == bom_UTF8[2] { - parser.encoding = yaml_UTF8_ENCODING - parser.raw_buffer_pos += 3 - parser.offset += 3 - } else { - parser.encoding = yaml_UTF8_ENCODING - } - return true -} - -// Update the raw buffer. -func yaml_parser_update_raw_buffer(parser *yaml_parser_t) bool { - size_read := 0 - - // Return if the raw buffer is full. - if parser.raw_buffer_pos == 0 && len(parser.raw_buffer) == cap(parser.raw_buffer) { - return true - } - - // Return on EOF. - if parser.eof { - return true - } - - // Move the remaining bytes in the raw buffer to the beginning. - if parser.raw_buffer_pos > 0 && parser.raw_buffer_pos < len(parser.raw_buffer) { - copy(parser.raw_buffer, parser.raw_buffer[parser.raw_buffer_pos:]) - } - parser.raw_buffer = parser.raw_buffer[:len(parser.raw_buffer)-parser.raw_buffer_pos] - parser.raw_buffer_pos = 0 - - // Call the read handler to fill the buffer. - size_read, err := parser.read_handler(parser, parser.raw_buffer[len(parser.raw_buffer):cap(parser.raw_buffer)]) - parser.raw_buffer = parser.raw_buffer[:len(parser.raw_buffer)+size_read] - if err == io.EOF { - parser.eof = true - } else if err != nil { - return yaml_parser_set_reader_error(parser, "input error: "+err.Error(), parser.offset, -1) - } - return true -} - -// Ensure that the buffer contains at least `length` characters. -// Return true on success, false on failure. -// -// The length is supposed to be significantly less that the buffer size. -func yaml_parser_update_buffer(parser *yaml_parser_t, length int) bool { - if parser.read_handler == nil { - panic("read handler must be set") - } - - // [Go] This function was changed to guarantee the requested length size at EOF. - // The fact we need to do this is pretty awful, but the description above implies - // for that to be the case, and there are tests - - // If the EOF flag is set and the raw buffer is empty, do nothing. - if parser.eof && parser.raw_buffer_pos == len(parser.raw_buffer) { - // [Go] ACTUALLY! Read the documentation of this function above. - // This is just broken. To return true, we need to have the - // given length in the buffer. Not doing that means every single - // check that calls this function to make sure the buffer has a - // given length is Go) panicking; or C) accessing invalid memory. - //return true - } - - // Return if the buffer contains enough characters. - if parser.unread >= length { - return true - } - - // Determine the input encoding if it is not known yet. - if parser.encoding == yaml_ANY_ENCODING { - if !yaml_parser_determine_encoding(parser) { - return false - } - } - - // Move the unread characters to the beginning of the buffer. - buffer_len := len(parser.buffer) - if parser.buffer_pos > 0 && parser.buffer_pos < buffer_len { - copy(parser.buffer, parser.buffer[parser.buffer_pos:]) - buffer_len -= parser.buffer_pos - parser.buffer_pos = 0 - } else if parser.buffer_pos == buffer_len { - buffer_len = 0 - parser.buffer_pos = 0 - } - - // Open the whole buffer for writing, and cut it before returning. - parser.buffer = parser.buffer[:cap(parser.buffer)] - - // Fill the buffer until it has enough characters. - first := true - for parser.unread < length { - - // Fill the raw buffer if necessary. - if !first || parser.raw_buffer_pos == len(parser.raw_buffer) { - if !yaml_parser_update_raw_buffer(parser) { - parser.buffer = parser.buffer[:buffer_len] - return false - } - } - first = false - - // Decode the raw buffer. - inner: - for parser.raw_buffer_pos != len(parser.raw_buffer) { - var value rune - var width int - - raw_unread := len(parser.raw_buffer) - parser.raw_buffer_pos - - // Decode the next character. - switch parser.encoding { - case yaml_UTF8_ENCODING: - // Decode a UTF-8 character. Check RFC 3629 - // (http://www.ietf.org/rfc/rfc3629.txt) for more details. - // - // The following table (taken from the RFC) is used for - // decoding. - // - // Char. number range | UTF-8 octet sequence - // (hexadecimal) | (binary) - // --------------------+------------------------------------ - // 0000 0000-0000 007F | 0xxxxxxx - // 0000 0080-0000 07FF | 110xxxxx 10xxxxxx - // 0000 0800-0000 FFFF | 1110xxxx 10xxxxxx 10xxxxxx - // 0001 0000-0010 FFFF | 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx - // - // Additionally, the characters in the range 0xD800-0xDFFF - // are prohibited as they are reserved for use with UTF-16 - // surrogate pairs. - - // Determine the length of the UTF-8 sequence. - octet := parser.raw_buffer[parser.raw_buffer_pos] - switch { - case octet&0x80 == 0x00: - width = 1 - case octet&0xE0 == 0xC0: - width = 2 - case octet&0xF0 == 0xE0: - width = 3 - case octet&0xF8 == 0xF0: - width = 4 - default: - // The leading octet is invalid. - return yaml_parser_set_reader_error(parser, - "invalid leading UTF-8 octet", - parser.offset, int(octet)) - } - - // Check if the raw buffer contains an incomplete character. - if width > raw_unread { - if parser.eof { - return yaml_parser_set_reader_error(parser, - "incomplete UTF-8 octet sequence", - parser.offset, -1) - } - break inner - } - - // Decode the leading octet. - switch { - case octet&0x80 == 0x00: - value = rune(octet & 0x7F) - case octet&0xE0 == 0xC0: - value = rune(octet & 0x1F) - case octet&0xF0 == 0xE0: - value = rune(octet & 0x0F) - case octet&0xF8 == 0xF0: - value = rune(octet & 0x07) - default: - value = 0 - } - - // Check and decode the trailing octets. - for k := 1; k < width; k++ { - octet = parser.raw_buffer[parser.raw_buffer_pos+k] - - // Check if the octet is valid. - if (octet & 0xC0) != 0x80 { - return yaml_parser_set_reader_error(parser, - "invalid trailing UTF-8 octet", - parser.offset+k, int(octet)) - } - - // Decode the octet. - value = (value << 6) + rune(octet&0x3F) - } - - // Check the length of the sequence against the value. - switch { - case width == 1: - case width == 2 && value >= 0x80: - case width == 3 && value >= 0x800: - case width == 4 && value >= 0x10000: - default: - return yaml_parser_set_reader_error(parser, - "invalid length of a UTF-8 sequence", - parser.offset, -1) - } - - // Check the range of the value. - if value >= 0xD800 && value <= 0xDFFF || value > 0x10FFFF { - return yaml_parser_set_reader_error(parser, - "invalid Unicode character", - parser.offset, int(value)) - } - - case yaml_UTF16LE_ENCODING, yaml_UTF16BE_ENCODING: - var low, high int - if parser.encoding == yaml_UTF16LE_ENCODING { - low, high = 0, 1 - } else { - low, high = 1, 0 - } - - // The UTF-16 encoding is not as simple as one might - // naively think. Check RFC 2781 - // (http://www.ietf.org/rfc/rfc2781.txt). - // - // Normally, two subsequent bytes describe a Unicode - // character. However a special technique (called a - // surrogate pair) is used for specifying character - // values larger than 0xFFFF. - // - // A surrogate pair consists of two pseudo-characters: - // high surrogate area (0xD800-0xDBFF) - // low surrogate area (0xDC00-0xDFFF) - // - // The following formulas are used for decoding - // and encoding characters using surrogate pairs: - // - // U = U' + 0x10000 (0x01 00 00 <= U <= 0x10 FF FF) - // U' = yyyyyyyyyyxxxxxxxxxx (0 <= U' <= 0x0F FF FF) - // W1 = 110110yyyyyyyyyy - // W2 = 110111xxxxxxxxxx - // - // where U is the character value, W1 is the high surrogate - // area, W2 is the low surrogate area. - - // Check for incomplete UTF-16 character. - if raw_unread < 2 { - if parser.eof { - return yaml_parser_set_reader_error(parser, - "incomplete UTF-16 character", - parser.offset, -1) - } - break inner - } - - // Get the character. - value = rune(parser.raw_buffer[parser.raw_buffer_pos+low]) + - (rune(parser.raw_buffer[parser.raw_buffer_pos+high]) << 8) - - // Check for unexpected low surrogate area. - if value&0xFC00 == 0xDC00 { - return yaml_parser_set_reader_error(parser, - "unexpected low surrogate area", - parser.offset, int(value)) - } - - // Check for a high surrogate area. - if value&0xFC00 == 0xD800 { - width = 4 - - // Check for incomplete surrogate pair. - if raw_unread < 4 { - if parser.eof { - return yaml_parser_set_reader_error(parser, - "incomplete UTF-16 surrogate pair", - parser.offset, -1) - } - break inner - } - - // Get the next character. - value2 := rune(parser.raw_buffer[parser.raw_buffer_pos+low+2]) + - (rune(parser.raw_buffer[parser.raw_buffer_pos+high+2]) << 8) - - // Check for a low surrogate area. - if value2&0xFC00 != 0xDC00 { - return yaml_parser_set_reader_error(parser, - "expected low surrogate area", - parser.offset+2, int(value2)) - } - - // Generate the value of the surrogate pair. - value = 0x10000 + ((value & 0x3FF) << 10) + (value2 & 0x3FF) - } else { - width = 2 - } - - default: - panic("impossible") - } - - // Check if the character is in the allowed range: - // #x9 | #xA | #xD | [#x20-#x7E] (8 bit) - // | #x85 | [#xA0-#xD7FF] | [#xE000-#xFFFD] (16 bit) - // | [#x10000-#x10FFFF] (32 bit) - switch { - case value == 0x09: - case value == 0x0A: - case value == 0x0D: - case value >= 0x20 && value <= 0x7E: - case value == 0x85: - case value >= 0xA0 && value <= 0xD7FF: - case value >= 0xE000 && value <= 0xFFFD: - case value >= 0x10000 && value <= 0x10FFFF: - default: - return yaml_parser_set_reader_error(parser, - "control characters are not allowed", - parser.offset, int(value)) - } - - // Move the raw pointers. - parser.raw_buffer_pos += width - parser.offset += width - - // Finally put the character into the buffer. - if value <= 0x7F { - // 0000 0000-0000 007F . 0xxxxxxx - parser.buffer[buffer_len+0] = byte(value) - buffer_len += 1 - } else if value <= 0x7FF { - // 0000 0080-0000 07FF . 110xxxxx 10xxxxxx - parser.buffer[buffer_len+0] = byte(0xC0 + (value >> 6)) - parser.buffer[buffer_len+1] = byte(0x80 + (value & 0x3F)) - buffer_len += 2 - } else if value <= 0xFFFF { - // 0000 0800-0000 FFFF . 1110xxxx 10xxxxxx 10xxxxxx - parser.buffer[buffer_len+0] = byte(0xE0 + (value >> 12)) - parser.buffer[buffer_len+1] = byte(0x80 + ((value >> 6) & 0x3F)) - parser.buffer[buffer_len+2] = byte(0x80 + (value & 0x3F)) - buffer_len += 3 - } else { - // 0001 0000-0010 FFFF . 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx - parser.buffer[buffer_len+0] = byte(0xF0 + (value >> 18)) - parser.buffer[buffer_len+1] = byte(0x80 + ((value >> 12) & 0x3F)) - parser.buffer[buffer_len+2] = byte(0x80 + ((value >> 6) & 0x3F)) - parser.buffer[buffer_len+3] = byte(0x80 + (value & 0x3F)) - buffer_len += 4 - } - - parser.unread++ - } - - // On EOF, put NUL into the buffer and return. - if parser.eof { - parser.buffer[buffer_len] = 0 - buffer_len++ - parser.unread++ - break - } - } - // [Go] Read the documentation of this function above. To return true, - // we need to have the given length in the buffer. Not doing that means - // every single check that calls this function to make sure the buffer - // has a given length is Go) panicking; or C) accessing invalid memory. - // This happens here due to the EOF above breaking early. - for buffer_len < length { - parser.buffer[buffer_len] = 0 - buffer_len++ - } - parser.buffer = parser.buffer[:buffer_len] - return true -} diff --git a/vendor/gopkg.in/yaml.v2/resolve.go b/vendor/gopkg.in/yaml.v2/resolve.go deleted file mode 100644 index 6c151db6fb..0000000000 --- a/vendor/gopkg.in/yaml.v2/resolve.go +++ /dev/null @@ -1,258 +0,0 @@ -package yaml - -import ( - "encoding/base64" - "math" - "regexp" - "strconv" - "strings" - "time" -) - -type resolveMapItem struct { - value interface{} - tag string -} - -var resolveTable = make([]byte, 256) -var resolveMap = make(map[string]resolveMapItem) - -func init() { - t := resolveTable - t[int('+')] = 'S' // Sign - t[int('-')] = 'S' - for _, c := range "0123456789" { - t[int(c)] = 'D' // Digit - } - for _, c := range "yYnNtTfFoO~" { - t[int(c)] = 'M' // In map - } - t[int('.')] = '.' // Float (potentially in map) - - var resolveMapList = []struct { - v interface{} - tag string - l []string - }{ - {true, yaml_BOOL_TAG, []string{"y", "Y", "yes", "Yes", "YES"}}, - {true, yaml_BOOL_TAG, []string{"true", "True", "TRUE"}}, - {true, yaml_BOOL_TAG, []string{"on", "On", "ON"}}, - {false, yaml_BOOL_TAG, []string{"n", "N", "no", "No", "NO"}}, - {false, yaml_BOOL_TAG, []string{"false", "False", "FALSE"}}, - {false, yaml_BOOL_TAG, []string{"off", "Off", "OFF"}}, - {nil, yaml_NULL_TAG, []string{"", "~", "null", "Null", "NULL"}}, - {math.NaN(), yaml_FLOAT_TAG, []string{".nan", ".NaN", ".NAN"}}, - {math.Inf(+1), yaml_FLOAT_TAG, []string{".inf", ".Inf", ".INF"}}, - {math.Inf(+1), yaml_FLOAT_TAG, []string{"+.inf", "+.Inf", "+.INF"}}, - {math.Inf(-1), yaml_FLOAT_TAG, []string{"-.inf", "-.Inf", "-.INF"}}, - {"<<", yaml_MERGE_TAG, []string{"<<"}}, - } - - m := resolveMap - for _, item := range resolveMapList { - for _, s := range item.l { - m[s] = resolveMapItem{item.v, item.tag} - } - } -} - -const longTagPrefix = "tag:yaml.org,2002:" - -func shortTag(tag string) string { - // TODO This can easily be made faster and produce less garbage. - if strings.HasPrefix(tag, longTagPrefix) { - return "!!" + tag[len(longTagPrefix):] - } - return tag -} - -func longTag(tag string) string { - if strings.HasPrefix(tag, "!!") { - return longTagPrefix + tag[2:] - } - return tag -} - -func resolvableTag(tag string) bool { - switch tag { - case "", yaml_STR_TAG, yaml_BOOL_TAG, yaml_INT_TAG, yaml_FLOAT_TAG, yaml_NULL_TAG, yaml_TIMESTAMP_TAG: - return true - } - return false -} - -var yamlStyleFloat = regexp.MustCompile(`^[-+]?[0-9]*\.?[0-9]+([eE][-+][0-9]+)?$`) - -func resolve(tag string, in string) (rtag string, out interface{}) { - if !resolvableTag(tag) { - return tag, in - } - - defer func() { - switch tag { - case "", rtag, yaml_STR_TAG, yaml_BINARY_TAG: - return - case yaml_FLOAT_TAG: - if rtag == yaml_INT_TAG { - switch v := out.(type) { - case int64: - rtag = yaml_FLOAT_TAG - out = float64(v) - return - case int: - rtag = yaml_FLOAT_TAG - out = float64(v) - return - } - } - } - failf("cannot decode %s `%s` as a %s", shortTag(rtag), in, shortTag(tag)) - }() - - // Any data is accepted as a !!str or !!binary. - // Otherwise, the prefix is enough of a hint about what it might be. - hint := byte('N') - if in != "" { - hint = resolveTable[in[0]] - } - if hint != 0 && tag != yaml_STR_TAG && tag != yaml_BINARY_TAG { - // Handle things we can lookup in a map. - if item, ok := resolveMap[in]; ok { - return item.tag, item.value - } - - // Base 60 floats are a bad idea, were dropped in YAML 1.2, and - // are purposefully unsupported here. They're still quoted on - // the way out for compatibility with other parser, though. - - switch hint { - case 'M': - // We've already checked the map above. - - case '.': - // Not in the map, so maybe a normal float. - floatv, err := strconv.ParseFloat(in, 64) - if err == nil { - return yaml_FLOAT_TAG, floatv - } - - case 'D', 'S': - // Int, float, or timestamp. - // Only try values as a timestamp if the value is unquoted or there's an explicit - // !!timestamp tag. - if tag == "" || tag == yaml_TIMESTAMP_TAG { - t, ok := parseTimestamp(in) - if ok { - return yaml_TIMESTAMP_TAG, t - } - } - - plain := strings.Replace(in, "_", "", -1) - intv, err := strconv.ParseInt(plain, 0, 64) - if err == nil { - if intv == int64(int(intv)) { - return yaml_INT_TAG, int(intv) - } else { - return yaml_INT_TAG, intv - } - } - uintv, err := strconv.ParseUint(plain, 0, 64) - if err == nil { - return yaml_INT_TAG, uintv - } - if yamlStyleFloat.MatchString(plain) { - floatv, err := strconv.ParseFloat(plain, 64) - if err == nil { - return yaml_FLOAT_TAG, floatv - } - } - if strings.HasPrefix(plain, "0b") { - intv, err := strconv.ParseInt(plain[2:], 2, 64) - if err == nil { - if intv == int64(int(intv)) { - return yaml_INT_TAG, int(intv) - } else { - return yaml_INT_TAG, intv - } - } - uintv, err := strconv.ParseUint(plain[2:], 2, 64) - if err == nil { - return yaml_INT_TAG, uintv - } - } else if strings.HasPrefix(plain, "-0b") { - intv, err := strconv.ParseInt("-" + plain[3:], 2, 64) - if err == nil { - if true || intv == int64(int(intv)) { - return yaml_INT_TAG, int(intv) - } else { - return yaml_INT_TAG, intv - } - } - } - default: - panic("resolveTable item not yet handled: " + string(rune(hint)) + " (with " + in + ")") - } - } - return yaml_STR_TAG, in -} - -// encodeBase64 encodes s as base64 that is broken up into multiple lines -// as appropriate for the resulting length. -func encodeBase64(s string) string { - const lineLen = 70 - encLen := base64.StdEncoding.EncodedLen(len(s)) - lines := encLen/lineLen + 1 - buf := make([]byte, encLen*2+lines) - in := buf[0:encLen] - out := buf[encLen:] - base64.StdEncoding.Encode(in, []byte(s)) - k := 0 - for i := 0; i < len(in); i += lineLen { - j := i + lineLen - if j > len(in) { - j = len(in) - } - k += copy(out[k:], in[i:j]) - if lines > 1 { - out[k] = '\n' - k++ - } - } - return string(out[:k]) -} - -// This is a subset of the formats allowed by the regular expression -// defined at http://yaml.org/type/timestamp.html. -var allowedTimestampFormats = []string{ - "2006-1-2T15:4:5.999999999Z07:00", // RCF3339Nano with short date fields. - "2006-1-2t15:4:5.999999999Z07:00", // RFC3339Nano with short date fields and lower-case "t". - "2006-1-2 15:4:5.999999999", // space separated with no time zone - "2006-1-2", // date only - // Notable exception: time.Parse cannot handle: "2001-12-14 21:59:43.10 -5" - // from the set of examples. -} - -// parseTimestamp parses s as a timestamp string and -// returns the timestamp and reports whether it succeeded. -// Timestamp formats are defined at http://yaml.org/type/timestamp.html -func parseTimestamp(s string) (time.Time, bool) { - // TODO write code to check all the formats supported by - // http://yaml.org/type/timestamp.html instead of using time.Parse. - - // Quick check: all date formats start with YYYY-. - i := 0 - for ; i < len(s); i++ { - if c := s[i]; c < '0' || c > '9' { - break - } - } - if i != 4 || i == len(s) || s[i] != '-' { - return time.Time{}, false - } - for _, format := range allowedTimestampFormats { - if t, err := time.Parse(format, s); err == nil { - return t, true - } - } - return time.Time{}, false -} diff --git a/vendor/gopkg.in/yaml.v2/scannerc.go b/vendor/gopkg.in/yaml.v2/scannerc.go deleted file mode 100644 index 077fd1dd2d..0000000000 --- a/vendor/gopkg.in/yaml.v2/scannerc.go +++ /dev/null @@ -1,2696 +0,0 @@ -package yaml - -import ( - "bytes" - "fmt" -) - -// Introduction -// ************ -// -// The following notes assume that you are familiar with the YAML specification -// (http://yaml.org/spec/1.2/spec.html). We mostly follow it, although in -// some cases we are less restrictive that it requires. -// -// The process of transforming a YAML stream into a sequence of events is -// divided on two steps: Scanning and Parsing. -// -// The Scanner transforms the input stream into a sequence of tokens, while the -// parser transform the sequence of tokens produced by the Scanner into a -// sequence of parsing events. -// -// The Scanner is rather clever and complicated. The Parser, on the contrary, -// is a straightforward implementation of a recursive-descendant parser (or, -// LL(1) parser, as it is usually called). -// -// Actually there are two issues of Scanning that might be called "clever", the -// rest is quite straightforward. The issues are "block collection start" and -// "simple keys". Both issues are explained below in details. -// -// Here the Scanning step is explained and implemented. We start with the list -// of all the tokens produced by the Scanner together with short descriptions. -// -// Now, tokens: -// -// STREAM-START(encoding) # The stream start. -// STREAM-END # The stream end. -// VERSION-DIRECTIVE(major,minor) # The '%YAML' directive. -// TAG-DIRECTIVE(handle,prefix) # The '%TAG' directive. -// DOCUMENT-START # '---' -// DOCUMENT-END # '...' -// BLOCK-SEQUENCE-START # Indentation increase denoting a block -// BLOCK-MAPPING-START # sequence or a block mapping. -// BLOCK-END # Indentation decrease. -// FLOW-SEQUENCE-START # '[' -// FLOW-SEQUENCE-END # ']' -// BLOCK-SEQUENCE-START # '{' -// BLOCK-SEQUENCE-END # '}' -// BLOCK-ENTRY # '-' -// FLOW-ENTRY # ',' -// KEY # '?' or nothing (simple keys). -// VALUE # ':' -// ALIAS(anchor) # '*anchor' -// ANCHOR(anchor) # '&anchor' -// TAG(handle,suffix) # '!handle!suffix' -// SCALAR(value,style) # A scalar. -// -// The following two tokens are "virtual" tokens denoting the beginning and the -// end of the stream: -// -// STREAM-START(encoding) -// STREAM-END -// -// We pass the information about the input stream encoding with the -// STREAM-START token. -// -// The next two tokens are responsible for tags: -// -// VERSION-DIRECTIVE(major,minor) -// TAG-DIRECTIVE(handle,prefix) -// -// Example: -// -// %YAML 1.1 -// %TAG ! !foo -// %TAG !yaml! tag:yaml.org,2002: -// --- -// -// The correspoding sequence of tokens: -// -// STREAM-START(utf-8) -// VERSION-DIRECTIVE(1,1) -// TAG-DIRECTIVE("!","!foo") -// TAG-DIRECTIVE("!yaml","tag:yaml.org,2002:") -// DOCUMENT-START -// STREAM-END -// -// Note that the VERSION-DIRECTIVE and TAG-DIRECTIVE tokens occupy a whole -// line. -// -// The document start and end indicators are represented by: -// -// DOCUMENT-START -// DOCUMENT-END -// -// Note that if a YAML stream contains an implicit document (without '---' -// and '...' indicators), no DOCUMENT-START and DOCUMENT-END tokens will be -// produced. -// -// In the following examples, we present whole documents together with the -// produced tokens. -// -// 1. An implicit document: -// -// 'a scalar' -// -// Tokens: -// -// STREAM-START(utf-8) -// SCALAR("a scalar",single-quoted) -// STREAM-END -// -// 2. An explicit document: -// -// --- -// 'a scalar' -// ... -// -// Tokens: -// -// STREAM-START(utf-8) -// DOCUMENT-START -// SCALAR("a scalar",single-quoted) -// DOCUMENT-END -// STREAM-END -// -// 3. Several documents in a stream: -// -// 'a scalar' -// --- -// 'another scalar' -// --- -// 'yet another scalar' -// -// Tokens: -// -// STREAM-START(utf-8) -// SCALAR("a scalar",single-quoted) -// DOCUMENT-START -// SCALAR("another scalar",single-quoted) -// DOCUMENT-START -// SCALAR("yet another scalar",single-quoted) -// STREAM-END -// -// We have already introduced the SCALAR token above. The following tokens are -// used to describe aliases, anchors, tag, and scalars: -// -// ALIAS(anchor) -// ANCHOR(anchor) -// TAG(handle,suffix) -// SCALAR(value,style) -// -// The following series of examples illustrate the usage of these tokens: -// -// 1. A recursive sequence: -// -// &A [ *A ] -// -// Tokens: -// -// STREAM-START(utf-8) -// ANCHOR("A") -// FLOW-SEQUENCE-START -// ALIAS("A") -// FLOW-SEQUENCE-END -// STREAM-END -// -// 2. A tagged scalar: -// -// !!float "3.14" # A good approximation. -// -// Tokens: -// -// STREAM-START(utf-8) -// TAG("!!","float") -// SCALAR("3.14",double-quoted) -// STREAM-END -// -// 3. Various scalar styles: -// -// --- # Implicit empty plain scalars do not produce tokens. -// --- a plain scalar -// --- 'a single-quoted scalar' -// --- "a double-quoted scalar" -// --- |- -// a literal scalar -// --- >- -// a folded -// scalar -// -// Tokens: -// -// STREAM-START(utf-8) -// DOCUMENT-START -// DOCUMENT-START -// SCALAR("a plain scalar",plain) -// DOCUMENT-START -// SCALAR("a single-quoted scalar",single-quoted) -// DOCUMENT-START -// SCALAR("a double-quoted scalar",double-quoted) -// DOCUMENT-START -// SCALAR("a literal scalar",literal) -// DOCUMENT-START -// SCALAR("a folded scalar",folded) -// STREAM-END -// -// Now it's time to review collection-related tokens. We will start with -// flow collections: -// -// FLOW-SEQUENCE-START -// FLOW-SEQUENCE-END -// FLOW-MAPPING-START -// FLOW-MAPPING-END -// FLOW-ENTRY -// KEY -// VALUE -// -// The tokens FLOW-SEQUENCE-START, FLOW-SEQUENCE-END, FLOW-MAPPING-START, and -// FLOW-MAPPING-END represent the indicators '[', ']', '{', and '}' -// correspondingly. FLOW-ENTRY represent the ',' indicator. Finally the -// indicators '?' and ':', which are used for denoting mapping keys and values, -// are represented by the KEY and VALUE tokens. -// -// The following examples show flow collections: -// -// 1. A flow sequence: -// -// [item 1, item 2, item 3] -// -// Tokens: -// -// STREAM-START(utf-8) -// FLOW-SEQUENCE-START -// SCALAR("item 1",plain) -// FLOW-ENTRY -// SCALAR("item 2",plain) -// FLOW-ENTRY -// SCALAR("item 3",plain) -// FLOW-SEQUENCE-END -// STREAM-END -// -// 2. A flow mapping: -// -// { -// a simple key: a value, # Note that the KEY token is produced. -// ? a complex key: another value, -// } -// -// Tokens: -// -// STREAM-START(utf-8) -// FLOW-MAPPING-START -// KEY -// SCALAR("a simple key",plain) -// VALUE -// SCALAR("a value",plain) -// FLOW-ENTRY -// KEY -// SCALAR("a complex key",plain) -// VALUE -// SCALAR("another value",plain) -// FLOW-ENTRY -// FLOW-MAPPING-END -// STREAM-END -// -// A simple key is a key which is not denoted by the '?' indicator. Note that -// the Scanner still produce the KEY token whenever it encounters a simple key. -// -// For scanning block collections, the following tokens are used (note that we -// repeat KEY and VALUE here): -// -// BLOCK-SEQUENCE-START -// BLOCK-MAPPING-START -// BLOCK-END -// BLOCK-ENTRY -// KEY -// VALUE -// -// The tokens BLOCK-SEQUENCE-START and BLOCK-MAPPING-START denote indentation -// increase that precedes a block collection (cf. the INDENT token in Python). -// The token BLOCK-END denote indentation decrease that ends a block collection -// (cf. the DEDENT token in Python). However YAML has some syntax pecularities -// that makes detections of these tokens more complex. -// -// The tokens BLOCK-ENTRY, KEY, and VALUE are used to represent the indicators -// '-', '?', and ':' correspondingly. -// -// The following examples show how the tokens BLOCK-SEQUENCE-START, -// BLOCK-MAPPING-START, and BLOCK-END are emitted by the Scanner: -// -// 1. Block sequences: -// -// - item 1 -// - item 2 -// - -// - item 3.1 -// - item 3.2 -// - -// key 1: value 1 -// key 2: value 2 -// -// Tokens: -// -// STREAM-START(utf-8) -// BLOCK-SEQUENCE-START -// BLOCK-ENTRY -// SCALAR("item 1",plain) -// BLOCK-ENTRY -// SCALAR("item 2",plain) -// BLOCK-ENTRY -// BLOCK-SEQUENCE-START -// BLOCK-ENTRY -// SCALAR("item 3.1",plain) -// BLOCK-ENTRY -// SCALAR("item 3.2",plain) -// BLOCK-END -// BLOCK-ENTRY -// BLOCK-MAPPING-START -// KEY -// SCALAR("key 1",plain) -// VALUE -// SCALAR("value 1",plain) -// KEY -// SCALAR("key 2",plain) -// VALUE -// SCALAR("value 2",plain) -// BLOCK-END -// BLOCK-END -// STREAM-END -// -// 2. Block mappings: -// -// a simple key: a value # The KEY token is produced here. -// ? a complex key -// : another value -// a mapping: -// key 1: value 1 -// key 2: value 2 -// a sequence: -// - item 1 -// - item 2 -// -// Tokens: -// -// STREAM-START(utf-8) -// BLOCK-MAPPING-START -// KEY -// SCALAR("a simple key",plain) -// VALUE -// SCALAR("a value",plain) -// KEY -// SCALAR("a complex key",plain) -// VALUE -// SCALAR("another value",plain) -// KEY -// SCALAR("a mapping",plain) -// BLOCK-MAPPING-START -// KEY -// SCALAR("key 1",plain) -// VALUE -// SCALAR("value 1",plain) -// KEY -// SCALAR("key 2",plain) -// VALUE -// SCALAR("value 2",plain) -// BLOCK-END -// KEY -// SCALAR("a sequence",plain) -// VALUE -// BLOCK-SEQUENCE-START -// BLOCK-ENTRY -// SCALAR("item 1",plain) -// BLOCK-ENTRY -// SCALAR("item 2",plain) -// BLOCK-END -// BLOCK-END -// STREAM-END -// -// YAML does not always require to start a new block collection from a new -// line. If the current line contains only '-', '?', and ':' indicators, a new -// block collection may start at the current line. The following examples -// illustrate this case: -// -// 1. Collections in a sequence: -// -// - - item 1 -// - item 2 -// - key 1: value 1 -// key 2: value 2 -// - ? complex key -// : complex value -// -// Tokens: -// -// STREAM-START(utf-8) -// BLOCK-SEQUENCE-START -// BLOCK-ENTRY -// BLOCK-SEQUENCE-START -// BLOCK-ENTRY -// SCALAR("item 1",plain) -// BLOCK-ENTRY -// SCALAR("item 2",plain) -// BLOCK-END -// BLOCK-ENTRY -// BLOCK-MAPPING-START -// KEY -// SCALAR("key 1",plain) -// VALUE -// SCALAR("value 1",plain) -// KEY -// SCALAR("key 2",plain) -// VALUE -// SCALAR("value 2",plain) -// BLOCK-END -// BLOCK-ENTRY -// BLOCK-MAPPING-START -// KEY -// SCALAR("complex key") -// VALUE -// SCALAR("complex value") -// BLOCK-END -// BLOCK-END -// STREAM-END -// -// 2. Collections in a mapping: -// -// ? a sequence -// : - item 1 -// - item 2 -// ? a mapping -// : key 1: value 1 -// key 2: value 2 -// -// Tokens: -// -// STREAM-START(utf-8) -// BLOCK-MAPPING-START -// KEY -// SCALAR("a sequence",plain) -// VALUE -// BLOCK-SEQUENCE-START -// BLOCK-ENTRY -// SCALAR("item 1",plain) -// BLOCK-ENTRY -// SCALAR("item 2",plain) -// BLOCK-END -// KEY -// SCALAR("a mapping",plain) -// VALUE -// BLOCK-MAPPING-START -// KEY -// SCALAR("key 1",plain) -// VALUE -// SCALAR("value 1",plain) -// KEY -// SCALAR("key 2",plain) -// VALUE -// SCALAR("value 2",plain) -// BLOCK-END -// BLOCK-END -// STREAM-END -// -// YAML also permits non-indented sequences if they are included into a block -// mapping. In this case, the token BLOCK-SEQUENCE-START is not produced: -// -// key: -// - item 1 # BLOCK-SEQUENCE-START is NOT produced here. -// - item 2 -// -// Tokens: -// -// STREAM-START(utf-8) -// BLOCK-MAPPING-START -// KEY -// SCALAR("key",plain) -// VALUE -// BLOCK-ENTRY -// SCALAR("item 1",plain) -// BLOCK-ENTRY -// SCALAR("item 2",plain) -// BLOCK-END -// - -// Ensure that the buffer contains the required number of characters. -// Return true on success, false on failure (reader error or memory error). -func cache(parser *yaml_parser_t, length int) bool { - // [Go] This was inlined: !cache(A, B) -> unread < B && !update(A, B) - return parser.unread >= length || yaml_parser_update_buffer(parser, length) -} - -// Advance the buffer pointer. -func skip(parser *yaml_parser_t) { - parser.mark.index++ - parser.mark.column++ - parser.unread-- - parser.buffer_pos += width(parser.buffer[parser.buffer_pos]) -} - -func skip_line(parser *yaml_parser_t) { - if is_crlf(parser.buffer, parser.buffer_pos) { - parser.mark.index += 2 - parser.mark.column = 0 - parser.mark.line++ - parser.unread -= 2 - parser.buffer_pos += 2 - } else if is_break(parser.buffer, parser.buffer_pos) { - parser.mark.index++ - parser.mark.column = 0 - parser.mark.line++ - parser.unread-- - parser.buffer_pos += width(parser.buffer[parser.buffer_pos]) - } -} - -// Copy a character to a string buffer and advance pointers. -func read(parser *yaml_parser_t, s []byte) []byte { - w := width(parser.buffer[parser.buffer_pos]) - if w == 0 { - panic("invalid character sequence") - } - if len(s) == 0 { - s = make([]byte, 0, 32) - } - if w == 1 && len(s)+w <= cap(s) { - s = s[:len(s)+1] - s[len(s)-1] = parser.buffer[parser.buffer_pos] - parser.buffer_pos++ - } else { - s = append(s, parser.buffer[parser.buffer_pos:parser.buffer_pos+w]...) - parser.buffer_pos += w - } - parser.mark.index++ - parser.mark.column++ - parser.unread-- - return s -} - -// Copy a line break character to a string buffer and advance pointers. -func read_line(parser *yaml_parser_t, s []byte) []byte { - buf := parser.buffer - pos := parser.buffer_pos - switch { - case buf[pos] == '\r' && buf[pos+1] == '\n': - // CR LF . LF - s = append(s, '\n') - parser.buffer_pos += 2 - parser.mark.index++ - parser.unread-- - case buf[pos] == '\r' || buf[pos] == '\n': - // CR|LF . LF - s = append(s, '\n') - parser.buffer_pos += 1 - case buf[pos] == '\xC2' && buf[pos+1] == '\x85': - // NEL . LF - s = append(s, '\n') - parser.buffer_pos += 2 - case buf[pos] == '\xE2' && buf[pos+1] == '\x80' && (buf[pos+2] == '\xA8' || buf[pos+2] == '\xA9'): - // LS|PS . LS|PS - s = append(s, buf[parser.buffer_pos:pos+3]...) - parser.buffer_pos += 3 - default: - return s - } - parser.mark.index++ - parser.mark.column = 0 - parser.mark.line++ - parser.unread-- - return s -} - -// Get the next token. -func yaml_parser_scan(parser *yaml_parser_t, token *yaml_token_t) bool { - // Erase the token object. - *token = yaml_token_t{} // [Go] Is this necessary? - - // No tokens after STREAM-END or error. - if parser.stream_end_produced || parser.error != yaml_NO_ERROR { - return true - } - - // Ensure that the tokens queue contains enough tokens. - if !parser.token_available { - if !yaml_parser_fetch_more_tokens(parser) { - return false - } - } - - // Fetch the next token from the queue. - *token = parser.tokens[parser.tokens_head] - parser.tokens_head++ - parser.tokens_parsed++ - parser.token_available = false - - if token.typ == yaml_STREAM_END_TOKEN { - parser.stream_end_produced = true - } - return true -} - -// Set the scanner error and return false. -func yaml_parser_set_scanner_error(parser *yaml_parser_t, context string, context_mark yaml_mark_t, problem string) bool { - parser.error = yaml_SCANNER_ERROR - parser.context = context - parser.context_mark = context_mark - parser.problem = problem - parser.problem_mark = parser.mark - return false -} - -func yaml_parser_set_scanner_tag_error(parser *yaml_parser_t, directive bool, context_mark yaml_mark_t, problem string) bool { - context := "while parsing a tag" - if directive { - context = "while parsing a %TAG directive" - } - return yaml_parser_set_scanner_error(parser, context, context_mark, problem) -} - -func trace(args ...interface{}) func() { - pargs := append([]interface{}{"+++"}, args...) - fmt.Println(pargs...) - pargs = append([]interface{}{"---"}, args...) - return func() { fmt.Println(pargs...) } -} - -// Ensure that the tokens queue contains at least one token which can be -// returned to the Parser. -func yaml_parser_fetch_more_tokens(parser *yaml_parser_t) bool { - // While we need more tokens to fetch, do it. - for { - // Check if we really need to fetch more tokens. - need_more_tokens := false - - if parser.tokens_head == len(parser.tokens) { - // Queue is empty. - need_more_tokens = true - } else { - // Check if any potential simple key may occupy the head position. - if !yaml_parser_stale_simple_keys(parser) { - return false - } - - for i := range parser.simple_keys { - simple_key := &parser.simple_keys[i] - if simple_key.possible && simple_key.token_number == parser.tokens_parsed { - need_more_tokens = true - break - } - } - } - - // We are finished. - if !need_more_tokens { - break - } - // Fetch the next token. - if !yaml_parser_fetch_next_token(parser) { - return false - } - } - - parser.token_available = true - return true -} - -// The dispatcher for token fetchers. -func yaml_parser_fetch_next_token(parser *yaml_parser_t) bool { - // Ensure that the buffer is initialized. - if parser.unread < 1 && !yaml_parser_update_buffer(parser, 1) { - return false - } - - // Check if we just started scanning. Fetch STREAM-START then. - if !parser.stream_start_produced { - return yaml_parser_fetch_stream_start(parser) - } - - // Eat whitespaces and comments until we reach the next token. - if !yaml_parser_scan_to_next_token(parser) { - return false - } - - // Remove obsolete potential simple keys. - if !yaml_parser_stale_simple_keys(parser) { - return false - } - - // Check the indentation level against the current column. - if !yaml_parser_unroll_indent(parser, parser.mark.column) { - return false - } - - // Ensure that the buffer contains at least 4 characters. 4 is the length - // of the longest indicators ('--- ' and '... '). - if parser.unread < 4 && !yaml_parser_update_buffer(parser, 4) { - return false - } - - // Is it the end of the stream? - if is_z(parser.buffer, parser.buffer_pos) { - return yaml_parser_fetch_stream_end(parser) - } - - // Is it a directive? - if parser.mark.column == 0 && parser.buffer[parser.buffer_pos] == '%' { - return yaml_parser_fetch_directive(parser) - } - - buf := parser.buffer - pos := parser.buffer_pos - - // Is it the document start indicator? - if parser.mark.column == 0 && buf[pos] == '-' && buf[pos+1] == '-' && buf[pos+2] == '-' && is_blankz(buf, pos+3) { - return yaml_parser_fetch_document_indicator(parser, yaml_DOCUMENT_START_TOKEN) - } - - // Is it the document end indicator? - if parser.mark.column == 0 && buf[pos] == '.' && buf[pos+1] == '.' && buf[pos+2] == '.' && is_blankz(buf, pos+3) { - return yaml_parser_fetch_document_indicator(parser, yaml_DOCUMENT_END_TOKEN) - } - - // Is it the flow sequence start indicator? - if buf[pos] == '[' { - return yaml_parser_fetch_flow_collection_start(parser, yaml_FLOW_SEQUENCE_START_TOKEN) - } - - // Is it the flow mapping start indicator? - if parser.buffer[parser.buffer_pos] == '{' { - return yaml_parser_fetch_flow_collection_start(parser, yaml_FLOW_MAPPING_START_TOKEN) - } - - // Is it the flow sequence end indicator? - if parser.buffer[parser.buffer_pos] == ']' { - return yaml_parser_fetch_flow_collection_end(parser, - yaml_FLOW_SEQUENCE_END_TOKEN) - } - - // Is it the flow mapping end indicator? - if parser.buffer[parser.buffer_pos] == '}' { - return yaml_parser_fetch_flow_collection_end(parser, - yaml_FLOW_MAPPING_END_TOKEN) - } - - // Is it the flow entry indicator? - if parser.buffer[parser.buffer_pos] == ',' { - return yaml_parser_fetch_flow_entry(parser) - } - - // Is it the block entry indicator? - if parser.buffer[parser.buffer_pos] == '-' && is_blankz(parser.buffer, parser.buffer_pos+1) { - return yaml_parser_fetch_block_entry(parser) - } - - // Is it the key indicator? - if parser.buffer[parser.buffer_pos] == '?' && (parser.flow_level > 0 || is_blankz(parser.buffer, parser.buffer_pos+1)) { - return yaml_parser_fetch_key(parser) - } - - // Is it the value indicator? - if parser.buffer[parser.buffer_pos] == ':' && (parser.flow_level > 0 || is_blankz(parser.buffer, parser.buffer_pos+1)) { - return yaml_parser_fetch_value(parser) - } - - // Is it an alias? - if parser.buffer[parser.buffer_pos] == '*' { - return yaml_parser_fetch_anchor(parser, yaml_ALIAS_TOKEN) - } - - // Is it an anchor? - if parser.buffer[parser.buffer_pos] == '&' { - return yaml_parser_fetch_anchor(parser, yaml_ANCHOR_TOKEN) - } - - // Is it a tag? - if parser.buffer[parser.buffer_pos] == '!' { - return yaml_parser_fetch_tag(parser) - } - - // Is it a literal scalar? - if parser.buffer[parser.buffer_pos] == '|' && parser.flow_level == 0 { - return yaml_parser_fetch_block_scalar(parser, true) - } - - // Is it a folded scalar? - if parser.buffer[parser.buffer_pos] == '>' && parser.flow_level == 0 { - return yaml_parser_fetch_block_scalar(parser, false) - } - - // Is it a single-quoted scalar? - if parser.buffer[parser.buffer_pos] == '\'' { - return yaml_parser_fetch_flow_scalar(parser, true) - } - - // Is it a double-quoted scalar? - if parser.buffer[parser.buffer_pos] == '"' { - return yaml_parser_fetch_flow_scalar(parser, false) - } - - // Is it a plain scalar? - // - // A plain scalar may start with any non-blank characters except - // - // '-', '?', ':', ',', '[', ']', '{', '}', - // '#', '&', '*', '!', '|', '>', '\'', '\"', - // '%', '@', '`'. - // - // In the block context (and, for the '-' indicator, in the flow context - // too), it may also start with the characters - // - // '-', '?', ':' - // - // if it is followed by a non-space character. - // - // The last rule is more restrictive than the specification requires. - // [Go] Make this logic more reasonable. - //switch parser.buffer[parser.buffer_pos] { - //case '-', '?', ':', ',', '?', '-', ',', ':', ']', '[', '}', '{', '&', '#', '!', '*', '>', '|', '"', '\'', '@', '%', '-', '`': - //} - if !(is_blankz(parser.buffer, parser.buffer_pos) || parser.buffer[parser.buffer_pos] == '-' || - parser.buffer[parser.buffer_pos] == '?' || parser.buffer[parser.buffer_pos] == ':' || - parser.buffer[parser.buffer_pos] == ',' || parser.buffer[parser.buffer_pos] == '[' || - parser.buffer[parser.buffer_pos] == ']' || parser.buffer[parser.buffer_pos] == '{' || - parser.buffer[parser.buffer_pos] == '}' || parser.buffer[parser.buffer_pos] == '#' || - parser.buffer[parser.buffer_pos] == '&' || parser.buffer[parser.buffer_pos] == '*' || - parser.buffer[parser.buffer_pos] == '!' || parser.buffer[parser.buffer_pos] == '|' || - parser.buffer[parser.buffer_pos] == '>' || parser.buffer[parser.buffer_pos] == '\'' || - parser.buffer[parser.buffer_pos] == '"' || parser.buffer[parser.buffer_pos] == '%' || - parser.buffer[parser.buffer_pos] == '@' || parser.buffer[parser.buffer_pos] == '`') || - (parser.buffer[parser.buffer_pos] == '-' && !is_blank(parser.buffer, parser.buffer_pos+1)) || - (parser.flow_level == 0 && - (parser.buffer[parser.buffer_pos] == '?' || parser.buffer[parser.buffer_pos] == ':') && - !is_blankz(parser.buffer, parser.buffer_pos+1)) { - return yaml_parser_fetch_plain_scalar(parser) - } - - // If we don't determine the token type so far, it is an error. - return yaml_parser_set_scanner_error(parser, - "while scanning for the next token", parser.mark, - "found character that cannot start any token") -} - -// Check the list of potential simple keys and remove the positions that -// cannot contain simple keys anymore. -func yaml_parser_stale_simple_keys(parser *yaml_parser_t) bool { - // Check for a potential simple key for each flow level. - for i := range parser.simple_keys { - simple_key := &parser.simple_keys[i] - - // The specification requires that a simple key - // - // - is limited to a single line, - // - is shorter than 1024 characters. - if simple_key.possible && (simple_key.mark.line < parser.mark.line || simple_key.mark.index+1024 < parser.mark.index) { - - // Check if the potential simple key to be removed is required. - if simple_key.required { - return yaml_parser_set_scanner_error(parser, - "while scanning a simple key", simple_key.mark, - "could not find expected ':'") - } - simple_key.possible = false - } - } - return true -} - -// Check if a simple key may start at the current position and add it if -// needed. -func yaml_parser_save_simple_key(parser *yaml_parser_t) bool { - // A simple key is required at the current position if the scanner is in - // the block context and the current column coincides with the indentation - // level. - - required := parser.flow_level == 0 && parser.indent == parser.mark.column - - // - // If the current position may start a simple key, save it. - // - if parser.simple_key_allowed { - simple_key := yaml_simple_key_t{ - possible: true, - required: required, - token_number: parser.tokens_parsed + (len(parser.tokens) - parser.tokens_head), - } - simple_key.mark = parser.mark - - if !yaml_parser_remove_simple_key(parser) { - return false - } - parser.simple_keys[len(parser.simple_keys)-1] = simple_key - } - return true -} - -// Remove a potential simple key at the current flow level. -func yaml_parser_remove_simple_key(parser *yaml_parser_t) bool { - i := len(parser.simple_keys) - 1 - if parser.simple_keys[i].possible { - // If the key is required, it is an error. - if parser.simple_keys[i].required { - return yaml_parser_set_scanner_error(parser, - "while scanning a simple key", parser.simple_keys[i].mark, - "could not find expected ':'") - } - } - // Remove the key from the stack. - parser.simple_keys[i].possible = false - return true -} - -// Increase the flow level and resize the simple key list if needed. -func yaml_parser_increase_flow_level(parser *yaml_parser_t) bool { - // Reset the simple key on the next level. - parser.simple_keys = append(parser.simple_keys, yaml_simple_key_t{}) - - // Increase the flow level. - parser.flow_level++ - return true -} - -// Decrease the flow level. -func yaml_parser_decrease_flow_level(parser *yaml_parser_t) bool { - if parser.flow_level > 0 { - parser.flow_level-- - parser.simple_keys = parser.simple_keys[:len(parser.simple_keys)-1] - } - return true -} - -// Push the current indentation level to the stack and set the new level -// the current column is greater than the indentation level. In this case, -// append or insert the specified token into the token queue. -func yaml_parser_roll_indent(parser *yaml_parser_t, column, number int, typ yaml_token_type_t, mark yaml_mark_t) bool { - // In the flow context, do nothing. - if parser.flow_level > 0 { - return true - } - - if parser.indent < column { - // Push the current indentation level to the stack and set the new - // indentation level. - parser.indents = append(parser.indents, parser.indent) - parser.indent = column - - // Create a token and insert it into the queue. - token := yaml_token_t{ - typ: typ, - start_mark: mark, - end_mark: mark, - } - if number > -1 { - number -= parser.tokens_parsed - } - yaml_insert_token(parser, number, &token) - } - return true -} - -// Pop indentation levels from the indents stack until the current level -// becomes less or equal to the column. For each indentation level, append -// the BLOCK-END token. -func yaml_parser_unroll_indent(parser *yaml_parser_t, column int) bool { - // In the flow context, do nothing. - if parser.flow_level > 0 { - return true - } - - // Loop through the indentation levels in the stack. - for parser.indent > column { - // Create a token and append it to the queue. - token := yaml_token_t{ - typ: yaml_BLOCK_END_TOKEN, - start_mark: parser.mark, - end_mark: parser.mark, - } - yaml_insert_token(parser, -1, &token) - - // Pop the indentation level. - parser.indent = parser.indents[len(parser.indents)-1] - parser.indents = parser.indents[:len(parser.indents)-1] - } - return true -} - -// Initialize the scanner and produce the STREAM-START token. -func yaml_parser_fetch_stream_start(parser *yaml_parser_t) bool { - - // Set the initial indentation. - parser.indent = -1 - - // Initialize the simple key stack. - parser.simple_keys = append(parser.simple_keys, yaml_simple_key_t{}) - - // A simple key is allowed at the beginning of the stream. - parser.simple_key_allowed = true - - // We have started. - parser.stream_start_produced = true - - // Create the STREAM-START token and append it to the queue. - token := yaml_token_t{ - typ: yaml_STREAM_START_TOKEN, - start_mark: parser.mark, - end_mark: parser.mark, - encoding: parser.encoding, - } - yaml_insert_token(parser, -1, &token) - return true -} - -// Produce the STREAM-END token and shut down the scanner. -func yaml_parser_fetch_stream_end(parser *yaml_parser_t) bool { - - // Force new line. - if parser.mark.column != 0 { - parser.mark.column = 0 - parser.mark.line++ - } - - // Reset the indentation level. - if !yaml_parser_unroll_indent(parser, -1) { - return false - } - - // Reset simple keys. - if !yaml_parser_remove_simple_key(parser) { - return false - } - - parser.simple_key_allowed = false - - // Create the STREAM-END token and append it to the queue. - token := yaml_token_t{ - typ: yaml_STREAM_END_TOKEN, - start_mark: parser.mark, - end_mark: parser.mark, - } - yaml_insert_token(parser, -1, &token) - return true -} - -// Produce a VERSION-DIRECTIVE or TAG-DIRECTIVE token. -func yaml_parser_fetch_directive(parser *yaml_parser_t) bool { - // Reset the indentation level. - if !yaml_parser_unroll_indent(parser, -1) { - return false - } - - // Reset simple keys. - if !yaml_parser_remove_simple_key(parser) { - return false - } - - parser.simple_key_allowed = false - - // Create the YAML-DIRECTIVE or TAG-DIRECTIVE token. - token := yaml_token_t{} - if !yaml_parser_scan_directive(parser, &token) { - return false - } - // Append the token to the queue. - yaml_insert_token(parser, -1, &token) - return true -} - -// Produce the DOCUMENT-START or DOCUMENT-END token. -func yaml_parser_fetch_document_indicator(parser *yaml_parser_t, typ yaml_token_type_t) bool { - // Reset the indentation level. - if !yaml_parser_unroll_indent(parser, -1) { - return false - } - - // Reset simple keys. - if !yaml_parser_remove_simple_key(parser) { - return false - } - - parser.simple_key_allowed = false - - // Consume the token. - start_mark := parser.mark - - skip(parser) - skip(parser) - skip(parser) - - end_mark := parser.mark - - // Create the DOCUMENT-START or DOCUMENT-END token. - token := yaml_token_t{ - typ: typ, - start_mark: start_mark, - end_mark: end_mark, - } - // Append the token to the queue. - yaml_insert_token(parser, -1, &token) - return true -} - -// Produce the FLOW-SEQUENCE-START or FLOW-MAPPING-START token. -func yaml_parser_fetch_flow_collection_start(parser *yaml_parser_t, typ yaml_token_type_t) bool { - // The indicators '[' and '{' may start a simple key. - if !yaml_parser_save_simple_key(parser) { - return false - } - - // Increase the flow level. - if !yaml_parser_increase_flow_level(parser) { - return false - } - - // A simple key may follow the indicators '[' and '{'. - parser.simple_key_allowed = true - - // Consume the token. - start_mark := parser.mark - skip(parser) - end_mark := parser.mark - - // Create the FLOW-SEQUENCE-START of FLOW-MAPPING-START token. - token := yaml_token_t{ - typ: typ, - start_mark: start_mark, - end_mark: end_mark, - } - // Append the token to the queue. - yaml_insert_token(parser, -1, &token) - return true -} - -// Produce the FLOW-SEQUENCE-END or FLOW-MAPPING-END token. -func yaml_parser_fetch_flow_collection_end(parser *yaml_parser_t, typ yaml_token_type_t) bool { - // Reset any potential simple key on the current flow level. - if !yaml_parser_remove_simple_key(parser) { - return false - } - - // Decrease the flow level. - if !yaml_parser_decrease_flow_level(parser) { - return false - } - - // No simple keys after the indicators ']' and '}'. - parser.simple_key_allowed = false - - // Consume the token. - - start_mark := parser.mark - skip(parser) - end_mark := parser.mark - - // Create the FLOW-SEQUENCE-END of FLOW-MAPPING-END token. - token := yaml_token_t{ - typ: typ, - start_mark: start_mark, - end_mark: end_mark, - } - // Append the token to the queue. - yaml_insert_token(parser, -1, &token) - return true -} - -// Produce the FLOW-ENTRY token. -func yaml_parser_fetch_flow_entry(parser *yaml_parser_t) bool { - // Reset any potential simple keys on the current flow level. - if !yaml_parser_remove_simple_key(parser) { - return false - } - - // Simple keys are allowed after ','. - parser.simple_key_allowed = true - - // Consume the token. - start_mark := parser.mark - skip(parser) - end_mark := parser.mark - - // Create the FLOW-ENTRY token and append it to the queue. - token := yaml_token_t{ - typ: yaml_FLOW_ENTRY_TOKEN, - start_mark: start_mark, - end_mark: end_mark, - } - yaml_insert_token(parser, -1, &token) - return true -} - -// Produce the BLOCK-ENTRY token. -func yaml_parser_fetch_block_entry(parser *yaml_parser_t) bool { - // Check if the scanner is in the block context. - if parser.flow_level == 0 { - // Check if we are allowed to start a new entry. - if !parser.simple_key_allowed { - return yaml_parser_set_scanner_error(parser, "", parser.mark, - "block sequence entries are not allowed in this context") - } - // Add the BLOCK-SEQUENCE-START token if needed. - if !yaml_parser_roll_indent(parser, parser.mark.column, -1, yaml_BLOCK_SEQUENCE_START_TOKEN, parser.mark) { - return false - } - } else { - // It is an error for the '-' indicator to occur in the flow context, - // but we let the Parser detect and report about it because the Parser - // is able to point to the context. - } - - // Reset any potential simple keys on the current flow level. - if !yaml_parser_remove_simple_key(parser) { - return false - } - - // Simple keys are allowed after '-'. - parser.simple_key_allowed = true - - // Consume the token. - start_mark := parser.mark - skip(parser) - end_mark := parser.mark - - // Create the BLOCK-ENTRY token and append it to the queue. - token := yaml_token_t{ - typ: yaml_BLOCK_ENTRY_TOKEN, - start_mark: start_mark, - end_mark: end_mark, - } - yaml_insert_token(parser, -1, &token) - return true -} - -// Produce the KEY token. -func yaml_parser_fetch_key(parser *yaml_parser_t) bool { - - // In the block context, additional checks are required. - if parser.flow_level == 0 { - // Check if we are allowed to start a new key (not nessesary simple). - if !parser.simple_key_allowed { - return yaml_parser_set_scanner_error(parser, "", parser.mark, - "mapping keys are not allowed in this context") - } - // Add the BLOCK-MAPPING-START token if needed. - if !yaml_parser_roll_indent(parser, parser.mark.column, -1, yaml_BLOCK_MAPPING_START_TOKEN, parser.mark) { - return false - } - } - - // Reset any potential simple keys on the current flow level. - if !yaml_parser_remove_simple_key(parser) { - return false - } - - // Simple keys are allowed after '?' in the block context. - parser.simple_key_allowed = parser.flow_level == 0 - - // Consume the token. - start_mark := parser.mark - skip(parser) - end_mark := parser.mark - - // Create the KEY token and append it to the queue. - token := yaml_token_t{ - typ: yaml_KEY_TOKEN, - start_mark: start_mark, - end_mark: end_mark, - } - yaml_insert_token(parser, -1, &token) - return true -} - -// Produce the VALUE token. -func yaml_parser_fetch_value(parser *yaml_parser_t) bool { - - simple_key := &parser.simple_keys[len(parser.simple_keys)-1] - - // Have we found a simple key? - if simple_key.possible { - // Create the KEY token and insert it into the queue. - token := yaml_token_t{ - typ: yaml_KEY_TOKEN, - start_mark: simple_key.mark, - end_mark: simple_key.mark, - } - yaml_insert_token(parser, simple_key.token_number-parser.tokens_parsed, &token) - - // In the block context, we may need to add the BLOCK-MAPPING-START token. - if !yaml_parser_roll_indent(parser, simple_key.mark.column, - simple_key.token_number, - yaml_BLOCK_MAPPING_START_TOKEN, simple_key.mark) { - return false - } - - // Remove the simple key. - simple_key.possible = false - - // A simple key cannot follow another simple key. - parser.simple_key_allowed = false - - } else { - // The ':' indicator follows a complex key. - - // In the block context, extra checks are required. - if parser.flow_level == 0 { - - // Check if we are allowed to start a complex value. - if !parser.simple_key_allowed { - return yaml_parser_set_scanner_error(parser, "", parser.mark, - "mapping values are not allowed in this context") - } - - // Add the BLOCK-MAPPING-START token if needed. - if !yaml_parser_roll_indent(parser, parser.mark.column, -1, yaml_BLOCK_MAPPING_START_TOKEN, parser.mark) { - return false - } - } - - // Simple keys after ':' are allowed in the block context. - parser.simple_key_allowed = parser.flow_level == 0 - } - - // Consume the token. - start_mark := parser.mark - skip(parser) - end_mark := parser.mark - - // Create the VALUE token and append it to the queue. - token := yaml_token_t{ - typ: yaml_VALUE_TOKEN, - start_mark: start_mark, - end_mark: end_mark, - } - yaml_insert_token(parser, -1, &token) - return true -} - -// Produce the ALIAS or ANCHOR token. -func yaml_parser_fetch_anchor(parser *yaml_parser_t, typ yaml_token_type_t) bool { - // An anchor or an alias could be a simple key. - if !yaml_parser_save_simple_key(parser) { - return false - } - - // A simple key cannot follow an anchor or an alias. - parser.simple_key_allowed = false - - // Create the ALIAS or ANCHOR token and append it to the queue. - var token yaml_token_t - if !yaml_parser_scan_anchor(parser, &token, typ) { - return false - } - yaml_insert_token(parser, -1, &token) - return true -} - -// Produce the TAG token. -func yaml_parser_fetch_tag(parser *yaml_parser_t) bool { - // A tag could be a simple key. - if !yaml_parser_save_simple_key(parser) { - return false - } - - // A simple key cannot follow a tag. - parser.simple_key_allowed = false - - // Create the TAG token and append it to the queue. - var token yaml_token_t - if !yaml_parser_scan_tag(parser, &token) { - return false - } - yaml_insert_token(parser, -1, &token) - return true -} - -// Produce the SCALAR(...,literal) or SCALAR(...,folded) tokens. -func yaml_parser_fetch_block_scalar(parser *yaml_parser_t, literal bool) bool { - // Remove any potential simple keys. - if !yaml_parser_remove_simple_key(parser) { - return false - } - - // A simple key may follow a block scalar. - parser.simple_key_allowed = true - - // Create the SCALAR token and append it to the queue. - var token yaml_token_t - if !yaml_parser_scan_block_scalar(parser, &token, literal) { - return false - } - yaml_insert_token(parser, -1, &token) - return true -} - -// Produce the SCALAR(...,single-quoted) or SCALAR(...,double-quoted) tokens. -func yaml_parser_fetch_flow_scalar(parser *yaml_parser_t, single bool) bool { - // A plain scalar could be a simple key. - if !yaml_parser_save_simple_key(parser) { - return false - } - - // A simple key cannot follow a flow scalar. - parser.simple_key_allowed = false - - // Create the SCALAR token and append it to the queue. - var token yaml_token_t - if !yaml_parser_scan_flow_scalar(parser, &token, single) { - return false - } - yaml_insert_token(parser, -1, &token) - return true -} - -// Produce the SCALAR(...,plain) token. -func yaml_parser_fetch_plain_scalar(parser *yaml_parser_t) bool { - // A plain scalar could be a simple key. - if !yaml_parser_save_simple_key(parser) { - return false - } - - // A simple key cannot follow a flow scalar. - parser.simple_key_allowed = false - - // Create the SCALAR token and append it to the queue. - var token yaml_token_t - if !yaml_parser_scan_plain_scalar(parser, &token) { - return false - } - yaml_insert_token(parser, -1, &token) - return true -} - -// Eat whitespaces and comments until the next token is found. -func yaml_parser_scan_to_next_token(parser *yaml_parser_t) bool { - - // Until the next token is not found. - for { - // Allow the BOM mark to start a line. - if parser.unread < 1 && !yaml_parser_update_buffer(parser, 1) { - return false - } - if parser.mark.column == 0 && is_bom(parser.buffer, parser.buffer_pos) { - skip(parser) - } - - // Eat whitespaces. - // Tabs are allowed: - // - in the flow context - // - in the block context, but not at the beginning of the line or - // after '-', '?', or ':' (complex value). - if parser.unread < 1 && !yaml_parser_update_buffer(parser, 1) { - return false - } - - for parser.buffer[parser.buffer_pos] == ' ' || ((parser.flow_level > 0 || !parser.simple_key_allowed) && parser.buffer[parser.buffer_pos] == '\t') { - skip(parser) - if parser.unread < 1 && !yaml_parser_update_buffer(parser, 1) { - return false - } - } - - // Eat a comment until a line break. - if parser.buffer[parser.buffer_pos] == '#' { - for !is_breakz(parser.buffer, parser.buffer_pos) { - skip(parser) - if parser.unread < 1 && !yaml_parser_update_buffer(parser, 1) { - return false - } - } - } - - // If it is a line break, eat it. - if is_break(parser.buffer, parser.buffer_pos) { - if parser.unread < 2 && !yaml_parser_update_buffer(parser, 2) { - return false - } - skip_line(parser) - - // In the block context, a new line may start a simple key. - if parser.flow_level == 0 { - parser.simple_key_allowed = true - } - } else { - break // We have found a token. - } - } - - return true -} - -// Scan a YAML-DIRECTIVE or TAG-DIRECTIVE token. -// -// Scope: -// %YAML 1.1 # a comment \n -// ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -// %TAG !yaml! tag:yaml.org,2002: \n -// ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -// -func yaml_parser_scan_directive(parser *yaml_parser_t, token *yaml_token_t) bool { - // Eat '%'. - start_mark := parser.mark - skip(parser) - - // Scan the directive name. - var name []byte - if !yaml_parser_scan_directive_name(parser, start_mark, &name) { - return false - } - - // Is it a YAML directive? - if bytes.Equal(name, []byte("YAML")) { - // Scan the VERSION directive value. - var major, minor int8 - if !yaml_parser_scan_version_directive_value(parser, start_mark, &major, &minor) { - return false - } - end_mark := parser.mark - - // Create a VERSION-DIRECTIVE token. - *token = yaml_token_t{ - typ: yaml_VERSION_DIRECTIVE_TOKEN, - start_mark: start_mark, - end_mark: end_mark, - major: major, - minor: minor, - } - - // Is it a TAG directive? - } else if bytes.Equal(name, []byte("TAG")) { - // Scan the TAG directive value. - var handle, prefix []byte - if !yaml_parser_scan_tag_directive_value(parser, start_mark, &handle, &prefix) { - return false - } - end_mark := parser.mark - - // Create a TAG-DIRECTIVE token. - *token = yaml_token_t{ - typ: yaml_TAG_DIRECTIVE_TOKEN, - start_mark: start_mark, - end_mark: end_mark, - value: handle, - prefix: prefix, - } - - // Unknown directive. - } else { - yaml_parser_set_scanner_error(parser, "while scanning a directive", - start_mark, "found unknown directive name") - return false - } - - // Eat the rest of the line including any comments. - if parser.unread < 1 && !yaml_parser_update_buffer(parser, 1) { - return false - } - - for is_blank(parser.buffer, parser.buffer_pos) { - skip(parser) - if parser.unread < 1 && !yaml_parser_update_buffer(parser, 1) { - return false - } - } - - if parser.buffer[parser.buffer_pos] == '#' { - for !is_breakz(parser.buffer, parser.buffer_pos) { - skip(parser) - if parser.unread < 1 && !yaml_parser_update_buffer(parser, 1) { - return false - } - } - } - - // Check if we are at the end of the line. - if !is_breakz(parser.buffer, parser.buffer_pos) { - yaml_parser_set_scanner_error(parser, "while scanning a directive", - start_mark, "did not find expected comment or line break") - return false - } - - // Eat a line break. - if is_break(parser.buffer, parser.buffer_pos) { - if parser.unread < 2 && !yaml_parser_update_buffer(parser, 2) { - return false - } - skip_line(parser) - } - - return true -} - -// Scan the directive name. -// -// Scope: -// %YAML 1.1 # a comment \n -// ^^^^ -// %TAG !yaml! tag:yaml.org,2002: \n -// ^^^ -// -func yaml_parser_scan_directive_name(parser *yaml_parser_t, start_mark yaml_mark_t, name *[]byte) bool { - // Consume the directive name. - if parser.unread < 1 && !yaml_parser_update_buffer(parser, 1) { - return false - } - - var s []byte - for is_alpha(parser.buffer, parser.buffer_pos) { - s = read(parser, s) - if parser.unread < 1 && !yaml_parser_update_buffer(parser, 1) { - return false - } - } - - // Check if the name is empty. - if len(s) == 0 { - yaml_parser_set_scanner_error(parser, "while scanning a directive", - start_mark, "could not find expected directive name") - return false - } - - // Check for an blank character after the name. - if !is_blankz(parser.buffer, parser.buffer_pos) { - yaml_parser_set_scanner_error(parser, "while scanning a directive", - start_mark, "found unexpected non-alphabetical character") - return false - } - *name = s - return true -} - -// Scan the value of VERSION-DIRECTIVE. -// -// Scope: -// %YAML 1.1 # a comment \n -// ^^^^^^ -func yaml_parser_scan_version_directive_value(parser *yaml_parser_t, start_mark yaml_mark_t, major, minor *int8) bool { - // Eat whitespaces. - if parser.unread < 1 && !yaml_parser_update_buffer(parser, 1) { - return false - } - for is_blank(parser.buffer, parser.buffer_pos) { - skip(parser) - if parser.unread < 1 && !yaml_parser_update_buffer(parser, 1) { - return false - } - } - - // Consume the major version number. - if !yaml_parser_scan_version_directive_number(parser, start_mark, major) { - return false - } - - // Eat '.'. - if parser.buffer[parser.buffer_pos] != '.' { - return yaml_parser_set_scanner_error(parser, "while scanning a %YAML directive", - start_mark, "did not find expected digit or '.' character") - } - - skip(parser) - - // Consume the minor version number. - if !yaml_parser_scan_version_directive_number(parser, start_mark, minor) { - return false - } - return true -} - -const max_number_length = 2 - -// Scan the version number of VERSION-DIRECTIVE. -// -// Scope: -// %YAML 1.1 # a comment \n -// ^ -// %YAML 1.1 # a comment \n -// ^ -func yaml_parser_scan_version_directive_number(parser *yaml_parser_t, start_mark yaml_mark_t, number *int8) bool { - - // Repeat while the next character is digit. - if parser.unread < 1 && !yaml_parser_update_buffer(parser, 1) { - return false - } - var value, length int8 - for is_digit(parser.buffer, parser.buffer_pos) { - // Check if the number is too long. - length++ - if length > max_number_length { - return yaml_parser_set_scanner_error(parser, "while scanning a %YAML directive", - start_mark, "found extremely long version number") - } - value = value*10 + int8(as_digit(parser.buffer, parser.buffer_pos)) - skip(parser) - if parser.unread < 1 && !yaml_parser_update_buffer(parser, 1) { - return false - } - } - - // Check if the number was present. - if length == 0 { - return yaml_parser_set_scanner_error(parser, "while scanning a %YAML directive", - start_mark, "did not find expected version number") - } - *number = value - return true -} - -// Scan the value of a TAG-DIRECTIVE token. -// -// Scope: -// %TAG !yaml! tag:yaml.org,2002: \n -// ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -// -func yaml_parser_scan_tag_directive_value(parser *yaml_parser_t, start_mark yaml_mark_t, handle, prefix *[]byte) bool { - var handle_value, prefix_value []byte - - // Eat whitespaces. - if parser.unread < 1 && !yaml_parser_update_buffer(parser, 1) { - return false - } - - for is_blank(parser.buffer, parser.buffer_pos) { - skip(parser) - if parser.unread < 1 && !yaml_parser_update_buffer(parser, 1) { - return false - } - } - - // Scan a handle. - if !yaml_parser_scan_tag_handle(parser, true, start_mark, &handle_value) { - return false - } - - // Expect a whitespace. - if parser.unread < 1 && !yaml_parser_update_buffer(parser, 1) { - return false - } - if !is_blank(parser.buffer, parser.buffer_pos) { - yaml_parser_set_scanner_error(parser, "while scanning a %TAG directive", - start_mark, "did not find expected whitespace") - return false - } - - // Eat whitespaces. - for is_blank(parser.buffer, parser.buffer_pos) { - skip(parser) - if parser.unread < 1 && !yaml_parser_update_buffer(parser, 1) { - return false - } - } - - // Scan a prefix. - if !yaml_parser_scan_tag_uri(parser, true, nil, start_mark, &prefix_value) { - return false - } - - // Expect a whitespace or line break. - if parser.unread < 1 && !yaml_parser_update_buffer(parser, 1) { - return false - } - if !is_blankz(parser.buffer, parser.buffer_pos) { - yaml_parser_set_scanner_error(parser, "while scanning a %TAG directive", - start_mark, "did not find expected whitespace or line break") - return false - } - - *handle = handle_value - *prefix = prefix_value - return true -} - -func yaml_parser_scan_anchor(parser *yaml_parser_t, token *yaml_token_t, typ yaml_token_type_t) bool { - var s []byte - - // Eat the indicator character. - start_mark := parser.mark - skip(parser) - - // Consume the value. - if parser.unread < 1 && !yaml_parser_update_buffer(parser, 1) { - return false - } - - for is_alpha(parser.buffer, parser.buffer_pos) { - s = read(parser, s) - if parser.unread < 1 && !yaml_parser_update_buffer(parser, 1) { - return false - } - } - - end_mark := parser.mark - - /* - * Check if length of the anchor is greater than 0 and it is followed by - * a whitespace character or one of the indicators: - * - * '?', ':', ',', ']', '}', '%', '@', '`'. - */ - - if len(s) == 0 || - !(is_blankz(parser.buffer, parser.buffer_pos) || parser.buffer[parser.buffer_pos] == '?' || - parser.buffer[parser.buffer_pos] == ':' || parser.buffer[parser.buffer_pos] == ',' || - parser.buffer[parser.buffer_pos] == ']' || parser.buffer[parser.buffer_pos] == '}' || - parser.buffer[parser.buffer_pos] == '%' || parser.buffer[parser.buffer_pos] == '@' || - parser.buffer[parser.buffer_pos] == '`') { - context := "while scanning an alias" - if typ == yaml_ANCHOR_TOKEN { - context = "while scanning an anchor" - } - yaml_parser_set_scanner_error(parser, context, start_mark, - "did not find expected alphabetic or numeric character") - return false - } - - // Create a token. - *token = yaml_token_t{ - typ: typ, - start_mark: start_mark, - end_mark: end_mark, - value: s, - } - - return true -} - -/* - * Scan a TAG token. - */ - -func yaml_parser_scan_tag(parser *yaml_parser_t, token *yaml_token_t) bool { - var handle, suffix []byte - - start_mark := parser.mark - - // Check if the tag is in the canonical form. - if parser.unread < 2 && !yaml_parser_update_buffer(parser, 2) { - return false - } - - if parser.buffer[parser.buffer_pos+1] == '<' { - // Keep the handle as '' - - // Eat '!<' - skip(parser) - skip(parser) - - // Consume the tag value. - if !yaml_parser_scan_tag_uri(parser, false, nil, start_mark, &suffix) { - return false - } - - // Check for '>' and eat it. - if parser.buffer[parser.buffer_pos] != '>' { - yaml_parser_set_scanner_error(parser, "while scanning a tag", - start_mark, "did not find the expected '>'") - return false - } - - skip(parser) - } else { - // The tag has either the '!suffix' or the '!handle!suffix' form. - - // First, try to scan a handle. - if !yaml_parser_scan_tag_handle(parser, false, start_mark, &handle) { - return false - } - - // Check if it is, indeed, handle. - if handle[0] == '!' && len(handle) > 1 && handle[len(handle)-1] == '!' { - // Scan the suffix now. - if !yaml_parser_scan_tag_uri(parser, false, nil, start_mark, &suffix) { - return false - } - } else { - // It wasn't a handle after all. Scan the rest of the tag. - if !yaml_parser_scan_tag_uri(parser, false, handle, start_mark, &suffix) { - return false - } - - // Set the handle to '!'. - handle = []byte{'!'} - - // A special case: the '!' tag. Set the handle to '' and the - // suffix to '!'. - if len(suffix) == 0 { - handle, suffix = suffix, handle - } - } - } - - // Check the character which ends the tag. - if parser.unread < 1 && !yaml_parser_update_buffer(parser, 1) { - return false - } - if !is_blankz(parser.buffer, parser.buffer_pos) { - yaml_parser_set_scanner_error(parser, "while scanning a tag", - start_mark, "did not find expected whitespace or line break") - return false - } - - end_mark := parser.mark - - // Create a token. - *token = yaml_token_t{ - typ: yaml_TAG_TOKEN, - start_mark: start_mark, - end_mark: end_mark, - value: handle, - suffix: suffix, - } - return true -} - -// Scan a tag handle. -func yaml_parser_scan_tag_handle(parser *yaml_parser_t, directive bool, start_mark yaml_mark_t, handle *[]byte) bool { - // Check the initial '!' character. - if parser.unread < 1 && !yaml_parser_update_buffer(parser, 1) { - return false - } - if parser.buffer[parser.buffer_pos] != '!' { - yaml_parser_set_scanner_tag_error(parser, directive, - start_mark, "did not find expected '!'") - return false - } - - var s []byte - - // Copy the '!' character. - s = read(parser, s) - - // Copy all subsequent alphabetical and numerical characters. - if parser.unread < 1 && !yaml_parser_update_buffer(parser, 1) { - return false - } - for is_alpha(parser.buffer, parser.buffer_pos) { - s = read(parser, s) - if parser.unread < 1 && !yaml_parser_update_buffer(parser, 1) { - return false - } - } - - // Check if the trailing character is '!' and copy it. - if parser.buffer[parser.buffer_pos] == '!' { - s = read(parser, s) - } else { - // It's either the '!' tag or not really a tag handle. If it's a %TAG - // directive, it's an error. If it's a tag token, it must be a part of URI. - if directive && string(s) != "!" { - yaml_parser_set_scanner_tag_error(parser, directive, - start_mark, "did not find expected '!'") - return false - } - } - - *handle = s - return true -} - -// Scan a tag. -func yaml_parser_scan_tag_uri(parser *yaml_parser_t, directive bool, head []byte, start_mark yaml_mark_t, uri *[]byte) bool { - //size_t length = head ? strlen((char *)head) : 0 - var s []byte - hasTag := len(head) > 0 - - // Copy the head if needed. - // - // Note that we don't copy the leading '!' character. - if len(head) > 1 { - s = append(s, head[1:]...) - } - - // Scan the tag. - if parser.unread < 1 && !yaml_parser_update_buffer(parser, 1) { - return false - } - - // The set of characters that may appear in URI is as follows: - // - // '0'-'9', 'A'-'Z', 'a'-'z', '_', '-', ';', '/', '?', ':', '@', '&', - // '=', '+', '$', ',', '.', '!', '~', '*', '\'', '(', ')', '[', ']', - // '%'. - // [Go] Convert this into more reasonable logic. - for is_alpha(parser.buffer, parser.buffer_pos) || parser.buffer[parser.buffer_pos] == ';' || - parser.buffer[parser.buffer_pos] == '/' || parser.buffer[parser.buffer_pos] == '?' || - parser.buffer[parser.buffer_pos] == ':' || parser.buffer[parser.buffer_pos] == '@' || - parser.buffer[parser.buffer_pos] == '&' || parser.buffer[parser.buffer_pos] == '=' || - parser.buffer[parser.buffer_pos] == '+' || parser.buffer[parser.buffer_pos] == '$' || - parser.buffer[parser.buffer_pos] == ',' || parser.buffer[parser.buffer_pos] == '.' || - parser.buffer[parser.buffer_pos] == '!' || parser.buffer[parser.buffer_pos] == '~' || - parser.buffer[parser.buffer_pos] == '*' || parser.buffer[parser.buffer_pos] == '\'' || - parser.buffer[parser.buffer_pos] == '(' || parser.buffer[parser.buffer_pos] == ')' || - parser.buffer[parser.buffer_pos] == '[' || parser.buffer[parser.buffer_pos] == ']' || - parser.buffer[parser.buffer_pos] == '%' { - // Check if it is a URI-escape sequence. - if parser.buffer[parser.buffer_pos] == '%' { - if !yaml_parser_scan_uri_escapes(parser, directive, start_mark, &s) { - return false - } - } else { - s = read(parser, s) - } - if parser.unread < 1 && !yaml_parser_update_buffer(parser, 1) { - return false - } - hasTag = true - } - - if !hasTag { - yaml_parser_set_scanner_tag_error(parser, directive, - start_mark, "did not find expected tag URI") - return false - } - *uri = s - return true -} - -// Decode an URI-escape sequence corresponding to a single UTF-8 character. -func yaml_parser_scan_uri_escapes(parser *yaml_parser_t, directive bool, start_mark yaml_mark_t, s *[]byte) bool { - - // Decode the required number of characters. - w := 1024 - for w > 0 { - // Check for a URI-escaped octet. - if parser.unread < 3 && !yaml_parser_update_buffer(parser, 3) { - return false - } - - if !(parser.buffer[parser.buffer_pos] == '%' && - is_hex(parser.buffer, parser.buffer_pos+1) && - is_hex(parser.buffer, parser.buffer_pos+2)) { - return yaml_parser_set_scanner_tag_error(parser, directive, - start_mark, "did not find URI escaped octet") - } - - // Get the octet. - octet := byte((as_hex(parser.buffer, parser.buffer_pos+1) << 4) + as_hex(parser.buffer, parser.buffer_pos+2)) - - // If it is the leading octet, determine the length of the UTF-8 sequence. - if w == 1024 { - w = width(octet) - if w == 0 { - return yaml_parser_set_scanner_tag_error(parser, directive, - start_mark, "found an incorrect leading UTF-8 octet") - } - } else { - // Check if the trailing octet is correct. - if octet&0xC0 != 0x80 { - return yaml_parser_set_scanner_tag_error(parser, directive, - start_mark, "found an incorrect trailing UTF-8 octet") - } - } - - // Copy the octet and move the pointers. - *s = append(*s, octet) - skip(parser) - skip(parser) - skip(parser) - w-- - } - return true -} - -// Scan a block scalar. -func yaml_parser_scan_block_scalar(parser *yaml_parser_t, token *yaml_token_t, literal bool) bool { - // Eat the indicator '|' or '>'. - start_mark := parser.mark - skip(parser) - - // Scan the additional block scalar indicators. - if parser.unread < 1 && !yaml_parser_update_buffer(parser, 1) { - return false - } - - // Check for a chomping indicator. - var chomping, increment int - if parser.buffer[parser.buffer_pos] == '+' || parser.buffer[parser.buffer_pos] == '-' { - // Set the chomping method and eat the indicator. - if parser.buffer[parser.buffer_pos] == '+' { - chomping = +1 - } else { - chomping = -1 - } - skip(parser) - - // Check for an indentation indicator. - if parser.unread < 1 && !yaml_parser_update_buffer(parser, 1) { - return false - } - if is_digit(parser.buffer, parser.buffer_pos) { - // Check that the indentation is greater than 0. - if parser.buffer[parser.buffer_pos] == '0' { - yaml_parser_set_scanner_error(parser, "while scanning a block scalar", - start_mark, "found an indentation indicator equal to 0") - return false - } - - // Get the indentation level and eat the indicator. - increment = as_digit(parser.buffer, parser.buffer_pos) - skip(parser) - } - - } else if is_digit(parser.buffer, parser.buffer_pos) { - // Do the same as above, but in the opposite order. - - if parser.buffer[parser.buffer_pos] == '0' { - yaml_parser_set_scanner_error(parser, "while scanning a block scalar", - start_mark, "found an indentation indicator equal to 0") - return false - } - increment = as_digit(parser.buffer, parser.buffer_pos) - skip(parser) - - if parser.unread < 1 && !yaml_parser_update_buffer(parser, 1) { - return false - } - if parser.buffer[parser.buffer_pos] == '+' || parser.buffer[parser.buffer_pos] == '-' { - if parser.buffer[parser.buffer_pos] == '+' { - chomping = +1 - } else { - chomping = -1 - } - skip(parser) - } - } - - // Eat whitespaces and comments to the end of the line. - if parser.unread < 1 && !yaml_parser_update_buffer(parser, 1) { - return false - } - for is_blank(parser.buffer, parser.buffer_pos) { - skip(parser) - if parser.unread < 1 && !yaml_parser_update_buffer(parser, 1) { - return false - } - } - if parser.buffer[parser.buffer_pos] == '#' { - for !is_breakz(parser.buffer, parser.buffer_pos) { - skip(parser) - if parser.unread < 1 && !yaml_parser_update_buffer(parser, 1) { - return false - } - } - } - - // Check if we are at the end of the line. - if !is_breakz(parser.buffer, parser.buffer_pos) { - yaml_parser_set_scanner_error(parser, "while scanning a block scalar", - start_mark, "did not find expected comment or line break") - return false - } - - // Eat a line break. - if is_break(parser.buffer, parser.buffer_pos) { - if parser.unread < 2 && !yaml_parser_update_buffer(parser, 2) { - return false - } - skip_line(parser) - } - - end_mark := parser.mark - - // Set the indentation level if it was specified. - var indent int - if increment > 0 { - if parser.indent >= 0 { - indent = parser.indent + increment - } else { - indent = increment - } - } - - // Scan the leading line breaks and determine the indentation level if needed. - var s, leading_break, trailing_breaks []byte - if !yaml_parser_scan_block_scalar_breaks(parser, &indent, &trailing_breaks, start_mark, &end_mark) { - return false - } - - // Scan the block scalar content. - if parser.unread < 1 && !yaml_parser_update_buffer(parser, 1) { - return false - } - var leading_blank, trailing_blank bool - for parser.mark.column == indent && !is_z(parser.buffer, parser.buffer_pos) { - // We are at the beginning of a non-empty line. - - // Is it a trailing whitespace? - trailing_blank = is_blank(parser.buffer, parser.buffer_pos) - - // Check if we need to fold the leading line break. - if !literal && !leading_blank && !trailing_blank && len(leading_break) > 0 && leading_break[0] == '\n' { - // Do we need to join the lines by space? - if len(trailing_breaks) == 0 { - s = append(s, ' ') - } - } else { - s = append(s, leading_break...) - } - leading_break = leading_break[:0] - - // Append the remaining line breaks. - s = append(s, trailing_breaks...) - trailing_breaks = trailing_breaks[:0] - - // Is it a leading whitespace? - leading_blank = is_blank(parser.buffer, parser.buffer_pos) - - // Consume the current line. - for !is_breakz(parser.buffer, parser.buffer_pos) { - s = read(parser, s) - if parser.unread < 1 && !yaml_parser_update_buffer(parser, 1) { - return false - } - } - - // Consume the line break. - if parser.unread < 2 && !yaml_parser_update_buffer(parser, 2) { - return false - } - - leading_break = read_line(parser, leading_break) - - // Eat the following indentation spaces and line breaks. - if !yaml_parser_scan_block_scalar_breaks(parser, &indent, &trailing_breaks, start_mark, &end_mark) { - return false - } - } - - // Chomp the tail. - if chomping != -1 { - s = append(s, leading_break...) - } - if chomping == 1 { - s = append(s, trailing_breaks...) - } - - // Create a token. - *token = yaml_token_t{ - typ: yaml_SCALAR_TOKEN, - start_mark: start_mark, - end_mark: end_mark, - value: s, - style: yaml_LITERAL_SCALAR_STYLE, - } - if !literal { - token.style = yaml_FOLDED_SCALAR_STYLE - } - return true -} - -// Scan indentation spaces and line breaks for a block scalar. Determine the -// indentation level if needed. -func yaml_parser_scan_block_scalar_breaks(parser *yaml_parser_t, indent *int, breaks *[]byte, start_mark yaml_mark_t, end_mark *yaml_mark_t) bool { - *end_mark = parser.mark - - // Eat the indentation spaces and line breaks. - max_indent := 0 - for { - // Eat the indentation spaces. - if parser.unread < 1 && !yaml_parser_update_buffer(parser, 1) { - return false - } - for (*indent == 0 || parser.mark.column < *indent) && is_space(parser.buffer, parser.buffer_pos) { - skip(parser) - if parser.unread < 1 && !yaml_parser_update_buffer(parser, 1) { - return false - } - } - if parser.mark.column > max_indent { - max_indent = parser.mark.column - } - - // Check for a tab character messing the indentation. - if (*indent == 0 || parser.mark.column < *indent) && is_tab(parser.buffer, parser.buffer_pos) { - return yaml_parser_set_scanner_error(parser, "while scanning a block scalar", - start_mark, "found a tab character where an indentation space is expected") - } - - // Have we found a non-empty line? - if !is_break(parser.buffer, parser.buffer_pos) { - break - } - - // Consume the line break. - if parser.unread < 2 && !yaml_parser_update_buffer(parser, 2) { - return false - } - // [Go] Should really be returning breaks instead. - *breaks = read_line(parser, *breaks) - *end_mark = parser.mark - } - - // Determine the indentation level if needed. - if *indent == 0 { - *indent = max_indent - if *indent < parser.indent+1 { - *indent = parser.indent + 1 - } - if *indent < 1 { - *indent = 1 - } - } - return true -} - -// Scan a quoted scalar. -func yaml_parser_scan_flow_scalar(parser *yaml_parser_t, token *yaml_token_t, single bool) bool { - // Eat the left quote. - start_mark := parser.mark - skip(parser) - - // Consume the content of the quoted scalar. - var s, leading_break, trailing_breaks, whitespaces []byte - for { - // Check that there are no document indicators at the beginning of the line. - if parser.unread < 4 && !yaml_parser_update_buffer(parser, 4) { - return false - } - - if parser.mark.column == 0 && - ((parser.buffer[parser.buffer_pos+0] == '-' && - parser.buffer[parser.buffer_pos+1] == '-' && - parser.buffer[parser.buffer_pos+2] == '-') || - (parser.buffer[parser.buffer_pos+0] == '.' && - parser.buffer[parser.buffer_pos+1] == '.' && - parser.buffer[parser.buffer_pos+2] == '.')) && - is_blankz(parser.buffer, parser.buffer_pos+3) { - yaml_parser_set_scanner_error(parser, "while scanning a quoted scalar", - start_mark, "found unexpected document indicator") - return false - } - - // Check for EOF. - if is_z(parser.buffer, parser.buffer_pos) { - yaml_parser_set_scanner_error(parser, "while scanning a quoted scalar", - start_mark, "found unexpected end of stream") - return false - } - - // Consume non-blank characters. - leading_blanks := false - for !is_blankz(parser.buffer, parser.buffer_pos) { - if single && parser.buffer[parser.buffer_pos] == '\'' && parser.buffer[parser.buffer_pos+1] == '\'' { - // Is is an escaped single quote. - s = append(s, '\'') - skip(parser) - skip(parser) - - } else if single && parser.buffer[parser.buffer_pos] == '\'' { - // It is a right single quote. - break - } else if !single && parser.buffer[parser.buffer_pos] == '"' { - // It is a right double quote. - break - - } else if !single && parser.buffer[parser.buffer_pos] == '\\' && is_break(parser.buffer, parser.buffer_pos+1) { - // It is an escaped line break. - if parser.unread < 3 && !yaml_parser_update_buffer(parser, 3) { - return false - } - skip(parser) - skip_line(parser) - leading_blanks = true - break - - } else if !single && parser.buffer[parser.buffer_pos] == '\\' { - // It is an escape sequence. - code_length := 0 - - // Check the escape character. - switch parser.buffer[parser.buffer_pos+1] { - case '0': - s = append(s, 0) - case 'a': - s = append(s, '\x07') - case 'b': - s = append(s, '\x08') - case 't', '\t': - s = append(s, '\x09') - case 'n': - s = append(s, '\x0A') - case 'v': - s = append(s, '\x0B') - case 'f': - s = append(s, '\x0C') - case 'r': - s = append(s, '\x0D') - case 'e': - s = append(s, '\x1B') - case ' ': - s = append(s, '\x20') - case '"': - s = append(s, '"') - case '\'': - s = append(s, '\'') - case '\\': - s = append(s, '\\') - case 'N': // NEL (#x85) - s = append(s, '\xC2') - s = append(s, '\x85') - case '_': // #xA0 - s = append(s, '\xC2') - s = append(s, '\xA0') - case 'L': // LS (#x2028) - s = append(s, '\xE2') - s = append(s, '\x80') - s = append(s, '\xA8') - case 'P': // PS (#x2029) - s = append(s, '\xE2') - s = append(s, '\x80') - s = append(s, '\xA9') - case 'x': - code_length = 2 - case 'u': - code_length = 4 - case 'U': - code_length = 8 - default: - yaml_parser_set_scanner_error(parser, "while parsing a quoted scalar", - start_mark, "found unknown escape character") - return false - } - - skip(parser) - skip(parser) - - // Consume an arbitrary escape code. - if code_length > 0 { - var value int - - // Scan the character value. - if parser.unread < code_length && !yaml_parser_update_buffer(parser, code_length) { - return false - } - for k := 0; k < code_length; k++ { - if !is_hex(parser.buffer, parser.buffer_pos+k) { - yaml_parser_set_scanner_error(parser, "while parsing a quoted scalar", - start_mark, "did not find expected hexdecimal number") - return false - } - value = (value << 4) + as_hex(parser.buffer, parser.buffer_pos+k) - } - - // Check the value and write the character. - if (value >= 0xD800 && value <= 0xDFFF) || value > 0x10FFFF { - yaml_parser_set_scanner_error(parser, "while parsing a quoted scalar", - start_mark, "found invalid Unicode character escape code") - return false - } - if value <= 0x7F { - s = append(s, byte(value)) - } else if value <= 0x7FF { - s = append(s, byte(0xC0+(value>>6))) - s = append(s, byte(0x80+(value&0x3F))) - } else if value <= 0xFFFF { - s = append(s, byte(0xE0+(value>>12))) - s = append(s, byte(0x80+((value>>6)&0x3F))) - s = append(s, byte(0x80+(value&0x3F))) - } else { - s = append(s, byte(0xF0+(value>>18))) - s = append(s, byte(0x80+((value>>12)&0x3F))) - s = append(s, byte(0x80+((value>>6)&0x3F))) - s = append(s, byte(0x80+(value&0x3F))) - } - - // Advance the pointer. - for k := 0; k < code_length; k++ { - skip(parser) - } - } - } else { - // It is a non-escaped non-blank character. - s = read(parser, s) - } - if parser.unread < 2 && !yaml_parser_update_buffer(parser, 2) { - return false - } - } - - if parser.unread < 1 && !yaml_parser_update_buffer(parser, 1) { - return false - } - - // Check if we are at the end of the scalar. - if single { - if parser.buffer[parser.buffer_pos] == '\'' { - break - } - } else { - if parser.buffer[parser.buffer_pos] == '"' { - break - } - } - - // Consume blank characters. - for is_blank(parser.buffer, parser.buffer_pos) || is_break(parser.buffer, parser.buffer_pos) { - if is_blank(parser.buffer, parser.buffer_pos) { - // Consume a space or a tab character. - if !leading_blanks { - whitespaces = read(parser, whitespaces) - } else { - skip(parser) - } - } else { - if parser.unread < 2 && !yaml_parser_update_buffer(parser, 2) { - return false - } - - // Check if it is a first line break. - if !leading_blanks { - whitespaces = whitespaces[:0] - leading_break = read_line(parser, leading_break) - leading_blanks = true - } else { - trailing_breaks = read_line(parser, trailing_breaks) - } - } - if parser.unread < 1 && !yaml_parser_update_buffer(parser, 1) { - return false - } - } - - // Join the whitespaces or fold line breaks. - if leading_blanks { - // Do we need to fold line breaks? - if len(leading_break) > 0 && leading_break[0] == '\n' { - if len(trailing_breaks) == 0 { - s = append(s, ' ') - } else { - s = append(s, trailing_breaks...) - } - } else { - s = append(s, leading_break...) - s = append(s, trailing_breaks...) - } - trailing_breaks = trailing_breaks[:0] - leading_break = leading_break[:0] - } else { - s = append(s, whitespaces...) - whitespaces = whitespaces[:0] - } - } - - // Eat the right quote. - skip(parser) - end_mark := parser.mark - - // Create a token. - *token = yaml_token_t{ - typ: yaml_SCALAR_TOKEN, - start_mark: start_mark, - end_mark: end_mark, - value: s, - style: yaml_SINGLE_QUOTED_SCALAR_STYLE, - } - if !single { - token.style = yaml_DOUBLE_QUOTED_SCALAR_STYLE - } - return true -} - -// Scan a plain scalar. -func yaml_parser_scan_plain_scalar(parser *yaml_parser_t, token *yaml_token_t) bool { - - var s, leading_break, trailing_breaks, whitespaces []byte - var leading_blanks bool - var indent = parser.indent + 1 - - start_mark := parser.mark - end_mark := parser.mark - - // Consume the content of the plain scalar. - for { - // Check for a document indicator. - if parser.unread < 4 && !yaml_parser_update_buffer(parser, 4) { - return false - } - if parser.mark.column == 0 && - ((parser.buffer[parser.buffer_pos+0] == '-' && - parser.buffer[parser.buffer_pos+1] == '-' && - parser.buffer[parser.buffer_pos+2] == '-') || - (parser.buffer[parser.buffer_pos+0] == '.' && - parser.buffer[parser.buffer_pos+1] == '.' && - parser.buffer[parser.buffer_pos+2] == '.')) && - is_blankz(parser.buffer, parser.buffer_pos+3) { - break - } - - // Check for a comment. - if parser.buffer[parser.buffer_pos] == '#' { - break - } - - // Consume non-blank characters. - for !is_blankz(parser.buffer, parser.buffer_pos) { - - // Check for indicators that may end a plain scalar. - if (parser.buffer[parser.buffer_pos] == ':' && is_blankz(parser.buffer, parser.buffer_pos+1)) || - (parser.flow_level > 0 && - (parser.buffer[parser.buffer_pos] == ',' || - parser.buffer[parser.buffer_pos] == '?' || parser.buffer[parser.buffer_pos] == '[' || - parser.buffer[parser.buffer_pos] == ']' || parser.buffer[parser.buffer_pos] == '{' || - parser.buffer[parser.buffer_pos] == '}')) { - break - } - - // Check if we need to join whitespaces and breaks. - if leading_blanks || len(whitespaces) > 0 { - if leading_blanks { - // Do we need to fold line breaks? - if leading_break[0] == '\n' { - if len(trailing_breaks) == 0 { - s = append(s, ' ') - } else { - s = append(s, trailing_breaks...) - } - } else { - s = append(s, leading_break...) - s = append(s, trailing_breaks...) - } - trailing_breaks = trailing_breaks[:0] - leading_break = leading_break[:0] - leading_blanks = false - } else { - s = append(s, whitespaces...) - whitespaces = whitespaces[:0] - } - } - - // Copy the character. - s = read(parser, s) - - end_mark = parser.mark - if parser.unread < 2 && !yaml_parser_update_buffer(parser, 2) { - return false - } - } - - // Is it the end? - if !(is_blank(parser.buffer, parser.buffer_pos) || is_break(parser.buffer, parser.buffer_pos)) { - break - } - - // Consume blank characters. - if parser.unread < 1 && !yaml_parser_update_buffer(parser, 1) { - return false - } - - for is_blank(parser.buffer, parser.buffer_pos) || is_break(parser.buffer, parser.buffer_pos) { - if is_blank(parser.buffer, parser.buffer_pos) { - - // Check for tab characters that abuse indentation. - if leading_blanks && parser.mark.column < indent && is_tab(parser.buffer, parser.buffer_pos) { - yaml_parser_set_scanner_error(parser, "while scanning a plain scalar", - start_mark, "found a tab character that violates indentation") - return false - } - - // Consume a space or a tab character. - if !leading_blanks { - whitespaces = read(parser, whitespaces) - } else { - skip(parser) - } - } else { - if parser.unread < 2 && !yaml_parser_update_buffer(parser, 2) { - return false - } - - // Check if it is a first line break. - if !leading_blanks { - whitespaces = whitespaces[:0] - leading_break = read_line(parser, leading_break) - leading_blanks = true - } else { - trailing_breaks = read_line(parser, trailing_breaks) - } - } - if parser.unread < 1 && !yaml_parser_update_buffer(parser, 1) { - return false - } - } - - // Check indentation level. - if parser.flow_level == 0 && parser.mark.column < indent { - break - } - } - - // Create a token. - *token = yaml_token_t{ - typ: yaml_SCALAR_TOKEN, - start_mark: start_mark, - end_mark: end_mark, - value: s, - style: yaml_PLAIN_SCALAR_STYLE, - } - - // Note that we change the 'simple_key_allowed' flag. - if leading_blanks { - parser.simple_key_allowed = true - } - return true -} diff --git a/vendor/gopkg.in/yaml.v2/sorter.go b/vendor/gopkg.in/yaml.v2/sorter.go deleted file mode 100644 index 4c45e660a8..0000000000 --- a/vendor/gopkg.in/yaml.v2/sorter.go +++ /dev/null @@ -1,113 +0,0 @@ -package yaml - -import ( - "reflect" - "unicode" -) - -type keyList []reflect.Value - -func (l keyList) Len() int { return len(l) } -func (l keyList) Swap(i, j int) { l[i], l[j] = l[j], l[i] } -func (l keyList) Less(i, j int) bool { - a := l[i] - b := l[j] - ak := a.Kind() - bk := b.Kind() - for (ak == reflect.Interface || ak == reflect.Ptr) && !a.IsNil() { - a = a.Elem() - ak = a.Kind() - } - for (bk == reflect.Interface || bk == reflect.Ptr) && !b.IsNil() { - b = b.Elem() - bk = b.Kind() - } - af, aok := keyFloat(a) - bf, bok := keyFloat(b) - if aok && bok { - if af != bf { - return af < bf - } - if ak != bk { - return ak < bk - } - return numLess(a, b) - } - if ak != reflect.String || bk != reflect.String { - return ak < bk - } - ar, br := []rune(a.String()), []rune(b.String()) - for i := 0; i < len(ar) && i < len(br); i++ { - if ar[i] == br[i] { - continue - } - al := unicode.IsLetter(ar[i]) - bl := unicode.IsLetter(br[i]) - if al && bl { - return ar[i] < br[i] - } - if al || bl { - return bl - } - var ai, bi int - var an, bn int64 - if ar[i] == '0' || br[i] == '0' { - for j := i-1; j >= 0 && unicode.IsDigit(ar[j]); j-- { - if ar[j] != '0' { - an = 1 - bn = 1 - break - } - } - } - for ai = i; ai < len(ar) && unicode.IsDigit(ar[ai]); ai++ { - an = an*10 + int64(ar[ai]-'0') - } - for bi = i; bi < len(br) && unicode.IsDigit(br[bi]); bi++ { - bn = bn*10 + int64(br[bi]-'0') - } - if an != bn { - return an < bn - } - if ai != bi { - return ai < bi - } - return ar[i] < br[i] - } - return len(ar) < len(br) -} - -// keyFloat returns a float value for v if it is a number/bool -// and whether it is a number/bool or not. -func keyFloat(v reflect.Value) (f float64, ok bool) { - switch v.Kind() { - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - return float64(v.Int()), true - case reflect.Float32, reflect.Float64: - return v.Float(), true - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: - return float64(v.Uint()), true - case reflect.Bool: - if v.Bool() { - return 1, true - } - return 0, true - } - return 0, false -} - -// numLess returns whether a < b. -// a and b must necessarily have the same kind. -func numLess(a, b reflect.Value) bool { - switch a.Kind() { - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - return a.Int() < b.Int() - case reflect.Float32, reflect.Float64: - return a.Float() < b.Float() - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: - return a.Uint() < b.Uint() - case reflect.Bool: - return !a.Bool() && b.Bool() - } - panic("not a number") -} diff --git a/vendor/gopkg.in/yaml.v2/writerc.go b/vendor/gopkg.in/yaml.v2/writerc.go deleted file mode 100644 index a2dde608cb..0000000000 --- a/vendor/gopkg.in/yaml.v2/writerc.go +++ /dev/null @@ -1,26 +0,0 @@ -package yaml - -// Set the writer error and return false. -func yaml_emitter_set_writer_error(emitter *yaml_emitter_t, problem string) bool { - emitter.error = yaml_WRITER_ERROR - emitter.problem = problem - return false -} - -// Flush the output buffer. -func yaml_emitter_flush(emitter *yaml_emitter_t) bool { - if emitter.write_handler == nil { - panic("write handler not set") - } - - // Check if the buffer is empty. - if emitter.buffer_pos == 0 { - return true - } - - if err := emitter.write_handler(emitter, emitter.buffer[:emitter.buffer_pos]); err != nil { - return yaml_emitter_set_writer_error(emitter, "write error: "+err.Error()) - } - emitter.buffer_pos = 0 - return true -} diff --git a/vendor/gopkg.in/yaml.v2/yaml.go b/vendor/gopkg.in/yaml.v2/yaml.go deleted file mode 100644 index de85aa4cdb..0000000000 --- a/vendor/gopkg.in/yaml.v2/yaml.go +++ /dev/null @@ -1,466 +0,0 @@ -// Package yaml implements YAML support for the Go language. -// -// Source code and other details for the project are available at GitHub: -// -// https://github.com/go-yaml/yaml -// -package yaml - -import ( - "errors" - "fmt" - "io" - "reflect" - "strings" - "sync" -) - -// MapSlice encodes and decodes as a YAML map. -// The order of keys is preserved when encoding and decoding. -type MapSlice []MapItem - -// MapItem is an item in a MapSlice. -type MapItem struct { - Key, Value interface{} -} - -// The Unmarshaler interface may be implemented by types to customize their -// behavior when being unmarshaled from a YAML document. The UnmarshalYAML -// method receives a function that may be called to unmarshal the original -// YAML value into a field or variable. It is safe to call the unmarshal -// function parameter more than once if necessary. -type Unmarshaler interface { - UnmarshalYAML(unmarshal func(interface{}) error) error -} - -// The Marshaler interface may be implemented by types to customize their -// behavior when being marshaled into a YAML document. The returned value -// is marshaled in place of the original value implementing Marshaler. -// -// If an error is returned by MarshalYAML, the marshaling procedure stops -// and returns with the provided error. -type Marshaler interface { - MarshalYAML() (interface{}, error) -} - -// Unmarshal decodes the first document found within the in byte slice -// and assigns decoded values into the out value. -// -// Maps and pointers (to a struct, string, int, etc) are accepted as out -// values. If an internal pointer within a struct is not initialized, -// the yaml package will initialize it if necessary for unmarshalling -// the provided data. The out parameter must not be nil. -// -// The type of the decoded values should be compatible with the respective -// values in out. If one or more values cannot be decoded due to a type -// mismatches, decoding continues partially until the end of the YAML -// content, and a *yaml.TypeError is returned with details for all -// missed values. -// -// Struct fields are only unmarshalled if they are exported (have an -// upper case first letter), and are unmarshalled using the field name -// lowercased as the default key. Custom keys may be defined via the -// "yaml" name in the field tag: the content preceding the first comma -// is used as the key, and the following comma-separated options are -// used to tweak the marshalling process (see Marshal). -// Conflicting names result in a runtime error. -// -// For example: -// -// type T struct { -// F int `yaml:"a,omitempty"` -// B int -// } -// var t T -// yaml.Unmarshal([]byte("a: 1\nb: 2"), &t) -// -// See the documentation of Marshal for the format of tags and a list of -// supported tag options. -// -func Unmarshal(in []byte, out interface{}) (err error) { - return unmarshal(in, out, false) -} - -// UnmarshalStrict is like Unmarshal except that any fields that are found -// in the data that do not have corresponding struct members, or mapping -// keys that are duplicates, will result in -// an error. -func UnmarshalStrict(in []byte, out interface{}) (err error) { - return unmarshal(in, out, true) -} - -// A Decorder reads and decodes YAML values from an input stream. -type Decoder struct { - strict bool - parser *parser -} - -// NewDecoder returns a new decoder that reads from r. -// -// The decoder introduces its own buffering and may read -// data from r beyond the YAML values requested. -func NewDecoder(r io.Reader) *Decoder { - return &Decoder{ - parser: newParserFromReader(r), - } -} - -// SetStrict sets whether strict decoding behaviour is enabled when -// decoding items in the data (see UnmarshalStrict). By default, decoding is not strict. -func (dec *Decoder) SetStrict(strict bool) { - dec.strict = strict -} - -// Decode reads the next YAML-encoded value from its input -// and stores it in the value pointed to by v. -// -// See the documentation for Unmarshal for details about the -// conversion of YAML into a Go value. -func (dec *Decoder) Decode(v interface{}) (err error) { - d := newDecoder(dec.strict) - defer handleErr(&err) - node := dec.parser.parse() - if node == nil { - return io.EOF - } - out := reflect.ValueOf(v) - if out.Kind() == reflect.Ptr && !out.IsNil() { - out = out.Elem() - } - d.unmarshal(node, out) - if len(d.terrors) > 0 { - return &TypeError{d.terrors} - } - return nil -} - -func unmarshal(in []byte, out interface{}, strict bool) (err error) { - defer handleErr(&err) - d := newDecoder(strict) - p := newParser(in) - defer p.destroy() - node := p.parse() - if node != nil { - v := reflect.ValueOf(out) - if v.Kind() == reflect.Ptr && !v.IsNil() { - v = v.Elem() - } - d.unmarshal(node, v) - } - if len(d.terrors) > 0 { - return &TypeError{d.terrors} - } - return nil -} - -// Marshal serializes the value provided into a YAML document. The structure -// of the generated document will reflect the structure of the value itself. -// Maps and pointers (to struct, string, int, etc) are accepted as the in value. -// -// Struct fields are only marshalled if they are exported (have an upper case -// first letter), and are marshalled using the field name lowercased as the -// default key. Custom keys may be defined via the "yaml" name in the field -// tag: the content preceding the first comma is used as the key, and the -// following comma-separated options are used to tweak the marshalling process. -// Conflicting names result in a runtime error. -// -// The field tag format accepted is: -// -// `(...) yaml:"[][,[,]]" (...)` -// -// The following flags are currently supported: -// -// omitempty Only include the field if it's not set to the zero -// value for the type or to empty slices or maps. -// Zero valued structs will be omitted if all their public -// fields are zero, unless they implement an IsZero -// method (see the IsZeroer interface type), in which -// case the field will be included if that method returns true. -// -// flow Marshal using a flow style (useful for structs, -// sequences and maps). -// -// inline Inline the field, which must be a struct or a map, -// causing all of its fields or keys to be processed as if -// they were part of the outer struct. For maps, keys must -// not conflict with the yaml keys of other struct fields. -// -// In addition, if the key is "-", the field is ignored. -// -// For example: -// -// type T struct { -// F int `yaml:"a,omitempty"` -// B int -// } -// yaml.Marshal(&T{B: 2}) // Returns "b: 2\n" -// yaml.Marshal(&T{F: 1}} // Returns "a: 1\nb: 0\n" -// -func Marshal(in interface{}) (out []byte, err error) { - defer handleErr(&err) - e := newEncoder() - defer e.destroy() - e.marshalDoc("", reflect.ValueOf(in)) - e.finish() - out = e.out - return -} - -// An Encoder writes YAML values to an output stream. -type Encoder struct { - encoder *encoder -} - -// NewEncoder returns a new encoder that writes to w. -// The Encoder should be closed after use to flush all data -// to w. -func NewEncoder(w io.Writer) *Encoder { - return &Encoder{ - encoder: newEncoderWithWriter(w), - } -} - -// Encode writes the YAML encoding of v to the stream. -// If multiple items are encoded to the stream, the -// second and subsequent document will be preceded -// with a "---" document separator, but the first will not. -// -// See the documentation for Marshal for details about the conversion of Go -// values to YAML. -func (e *Encoder) Encode(v interface{}) (err error) { - defer handleErr(&err) - e.encoder.marshalDoc("", reflect.ValueOf(v)) - return nil -} - -// Close closes the encoder by writing any remaining data. -// It does not write a stream terminating string "...". -func (e *Encoder) Close() (err error) { - defer handleErr(&err) - e.encoder.finish() - return nil -} - -func handleErr(err *error) { - if v := recover(); v != nil { - if e, ok := v.(yamlError); ok { - *err = e.err - } else { - panic(v) - } - } -} - -type yamlError struct { - err error -} - -func fail(err error) { - panic(yamlError{err}) -} - -func failf(format string, args ...interface{}) { - panic(yamlError{fmt.Errorf("yaml: "+format, args...)}) -} - -// A TypeError is returned by Unmarshal when one or more fields in -// the YAML document cannot be properly decoded into the requested -// types. When this error is returned, the value is still -// unmarshaled partially. -type TypeError struct { - Errors []string -} - -func (e *TypeError) Error() string { - return fmt.Sprintf("yaml: unmarshal errors:\n %s", strings.Join(e.Errors, "\n ")) -} - -// -------------------------------------------------------------------------- -// Maintain a mapping of keys to structure field indexes - -// The code in this section was copied from mgo/bson. - -// structInfo holds details for the serialization of fields of -// a given struct. -type structInfo struct { - FieldsMap map[string]fieldInfo - FieldsList []fieldInfo - - // InlineMap is the number of the field in the struct that - // contains an ,inline map, or -1 if there's none. - InlineMap int -} - -type fieldInfo struct { - Key string - Num int - OmitEmpty bool - Flow bool - // Id holds the unique field identifier, so we can cheaply - // check for field duplicates without maintaining an extra map. - Id int - - // Inline holds the field index if the field is part of an inlined struct. - Inline []int -} - -var structMap = make(map[reflect.Type]*structInfo) -var fieldMapMutex sync.RWMutex - -func getStructInfo(st reflect.Type) (*structInfo, error) { - fieldMapMutex.RLock() - sinfo, found := structMap[st] - fieldMapMutex.RUnlock() - if found { - return sinfo, nil - } - - n := st.NumField() - fieldsMap := make(map[string]fieldInfo) - fieldsList := make([]fieldInfo, 0, n) - inlineMap := -1 - for i := 0; i != n; i++ { - field := st.Field(i) - if field.PkgPath != "" && !field.Anonymous { - continue // Private field - } - - info := fieldInfo{Num: i} - - tag := field.Tag.Get("yaml") - if tag == "" && strings.Index(string(field.Tag), ":") < 0 { - tag = string(field.Tag) - } - if tag == "-" { - continue - } - - inline := false - fields := strings.Split(tag, ",") - if len(fields) > 1 { - for _, flag := range fields[1:] { - switch flag { - case "omitempty": - info.OmitEmpty = true - case "flow": - info.Flow = true - case "inline": - inline = true - default: - return nil, errors.New(fmt.Sprintf("Unsupported flag %q in tag %q of type %s", flag, tag, st)) - } - } - tag = fields[0] - } - - if inline { - switch field.Type.Kind() { - case reflect.Map: - if inlineMap >= 0 { - return nil, errors.New("Multiple ,inline maps in struct " + st.String()) - } - if field.Type.Key() != reflect.TypeOf("") { - return nil, errors.New("Option ,inline needs a map with string keys in struct " + st.String()) - } - inlineMap = info.Num - case reflect.Struct: - sinfo, err := getStructInfo(field.Type) - if err != nil { - return nil, err - } - for _, finfo := range sinfo.FieldsList { - if _, found := fieldsMap[finfo.Key]; found { - msg := "Duplicated key '" + finfo.Key + "' in struct " + st.String() - return nil, errors.New(msg) - } - if finfo.Inline == nil { - finfo.Inline = []int{i, finfo.Num} - } else { - finfo.Inline = append([]int{i}, finfo.Inline...) - } - finfo.Id = len(fieldsList) - fieldsMap[finfo.Key] = finfo - fieldsList = append(fieldsList, finfo) - } - default: - //return nil, errors.New("Option ,inline needs a struct value or map field") - return nil, errors.New("Option ,inline needs a struct value field") - } - continue - } - - if tag != "" { - info.Key = tag - } else { - info.Key = strings.ToLower(field.Name) - } - - if _, found = fieldsMap[info.Key]; found { - msg := "Duplicated key '" + info.Key + "' in struct " + st.String() - return nil, errors.New(msg) - } - - info.Id = len(fieldsList) - fieldsList = append(fieldsList, info) - fieldsMap[info.Key] = info - } - - sinfo = &structInfo{ - FieldsMap: fieldsMap, - FieldsList: fieldsList, - InlineMap: inlineMap, - } - - fieldMapMutex.Lock() - structMap[st] = sinfo - fieldMapMutex.Unlock() - return sinfo, nil -} - -// IsZeroer is used to check whether an object is zero to -// determine whether it should be omitted when marshaling -// with the omitempty flag. One notable implementation -// is time.Time. -type IsZeroer interface { - IsZero() bool -} - -func isZero(v reflect.Value) bool { - kind := v.Kind() - if z, ok := v.Interface().(IsZeroer); ok { - if (kind == reflect.Ptr || kind == reflect.Interface) && v.IsNil() { - return true - } - return z.IsZero() - } - switch kind { - case reflect.String: - return len(v.String()) == 0 - case reflect.Interface, reflect.Ptr: - return v.IsNil() - case reflect.Slice: - return v.Len() == 0 - case reflect.Map: - return v.Len() == 0 - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - return v.Int() == 0 - case reflect.Float32, reflect.Float64: - return v.Float() == 0 - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: - return v.Uint() == 0 - case reflect.Bool: - return !v.Bool() - case reflect.Struct: - vt := v.Type() - for i := v.NumField() - 1; i >= 0; i-- { - if vt.Field(i).PkgPath != "" { - continue // Private field - } - if !isZero(v.Field(i)) { - return false - } - } - return true - } - return false -} diff --git a/vendor/gopkg.in/yaml.v2/yamlh.go b/vendor/gopkg.in/yaml.v2/yamlh.go deleted file mode 100644 index e25cee563b..0000000000 --- a/vendor/gopkg.in/yaml.v2/yamlh.go +++ /dev/null @@ -1,738 +0,0 @@ -package yaml - -import ( - "fmt" - "io" -) - -// The version directive data. -type yaml_version_directive_t struct { - major int8 // The major version number. - minor int8 // The minor version number. -} - -// The tag directive data. -type yaml_tag_directive_t struct { - handle []byte // The tag handle. - prefix []byte // The tag prefix. -} - -type yaml_encoding_t int - -// The stream encoding. -const ( - // Let the parser choose the encoding. - yaml_ANY_ENCODING yaml_encoding_t = iota - - yaml_UTF8_ENCODING // The default UTF-8 encoding. - yaml_UTF16LE_ENCODING // The UTF-16-LE encoding with BOM. - yaml_UTF16BE_ENCODING // The UTF-16-BE encoding with BOM. -) - -type yaml_break_t int - -// Line break types. -const ( - // Let the parser choose the break type. - yaml_ANY_BREAK yaml_break_t = iota - - yaml_CR_BREAK // Use CR for line breaks (Mac style). - yaml_LN_BREAK // Use LN for line breaks (Unix style). - yaml_CRLN_BREAK // Use CR LN for line breaks (DOS style). -) - -type yaml_error_type_t int - -// Many bad things could happen with the parser and emitter. -const ( - // No error is produced. - yaml_NO_ERROR yaml_error_type_t = iota - - yaml_MEMORY_ERROR // Cannot allocate or reallocate a block of memory. - yaml_READER_ERROR // Cannot read or decode the input stream. - yaml_SCANNER_ERROR // Cannot scan the input stream. - yaml_PARSER_ERROR // Cannot parse the input stream. - yaml_COMPOSER_ERROR // Cannot compose a YAML document. - yaml_WRITER_ERROR // Cannot write to the output stream. - yaml_EMITTER_ERROR // Cannot emit a YAML stream. -) - -// The pointer position. -type yaml_mark_t struct { - index int // The position index. - line int // The position line. - column int // The position column. -} - -// Node Styles - -type yaml_style_t int8 - -type yaml_scalar_style_t yaml_style_t - -// Scalar styles. -const ( - // Let the emitter choose the style. - yaml_ANY_SCALAR_STYLE yaml_scalar_style_t = iota - - yaml_PLAIN_SCALAR_STYLE // The plain scalar style. - yaml_SINGLE_QUOTED_SCALAR_STYLE // The single-quoted scalar style. - yaml_DOUBLE_QUOTED_SCALAR_STYLE // The double-quoted scalar style. - yaml_LITERAL_SCALAR_STYLE // The literal scalar style. - yaml_FOLDED_SCALAR_STYLE // The folded scalar style. -) - -type yaml_sequence_style_t yaml_style_t - -// Sequence styles. -const ( - // Let the emitter choose the style. - yaml_ANY_SEQUENCE_STYLE yaml_sequence_style_t = iota - - yaml_BLOCK_SEQUENCE_STYLE // The block sequence style. - yaml_FLOW_SEQUENCE_STYLE // The flow sequence style. -) - -type yaml_mapping_style_t yaml_style_t - -// Mapping styles. -const ( - // Let the emitter choose the style. - yaml_ANY_MAPPING_STYLE yaml_mapping_style_t = iota - - yaml_BLOCK_MAPPING_STYLE // The block mapping style. - yaml_FLOW_MAPPING_STYLE // The flow mapping style. -) - -// Tokens - -type yaml_token_type_t int - -// Token types. -const ( - // An empty token. - yaml_NO_TOKEN yaml_token_type_t = iota - - yaml_STREAM_START_TOKEN // A STREAM-START token. - yaml_STREAM_END_TOKEN // A STREAM-END token. - - yaml_VERSION_DIRECTIVE_TOKEN // A VERSION-DIRECTIVE token. - yaml_TAG_DIRECTIVE_TOKEN // A TAG-DIRECTIVE token. - yaml_DOCUMENT_START_TOKEN // A DOCUMENT-START token. - yaml_DOCUMENT_END_TOKEN // A DOCUMENT-END token. - - yaml_BLOCK_SEQUENCE_START_TOKEN // A BLOCK-SEQUENCE-START token. - yaml_BLOCK_MAPPING_START_TOKEN // A BLOCK-SEQUENCE-END token. - yaml_BLOCK_END_TOKEN // A BLOCK-END token. - - yaml_FLOW_SEQUENCE_START_TOKEN // A FLOW-SEQUENCE-START token. - yaml_FLOW_SEQUENCE_END_TOKEN // A FLOW-SEQUENCE-END token. - yaml_FLOW_MAPPING_START_TOKEN // A FLOW-MAPPING-START token. - yaml_FLOW_MAPPING_END_TOKEN // A FLOW-MAPPING-END token. - - yaml_BLOCK_ENTRY_TOKEN // A BLOCK-ENTRY token. - yaml_FLOW_ENTRY_TOKEN // A FLOW-ENTRY token. - yaml_KEY_TOKEN // A KEY token. - yaml_VALUE_TOKEN // A VALUE token. - - yaml_ALIAS_TOKEN // An ALIAS token. - yaml_ANCHOR_TOKEN // An ANCHOR token. - yaml_TAG_TOKEN // A TAG token. - yaml_SCALAR_TOKEN // A SCALAR token. -) - -func (tt yaml_token_type_t) String() string { - switch tt { - case yaml_NO_TOKEN: - return "yaml_NO_TOKEN" - case yaml_STREAM_START_TOKEN: - return "yaml_STREAM_START_TOKEN" - case yaml_STREAM_END_TOKEN: - return "yaml_STREAM_END_TOKEN" - case yaml_VERSION_DIRECTIVE_TOKEN: - return "yaml_VERSION_DIRECTIVE_TOKEN" - case yaml_TAG_DIRECTIVE_TOKEN: - return "yaml_TAG_DIRECTIVE_TOKEN" - case yaml_DOCUMENT_START_TOKEN: - return "yaml_DOCUMENT_START_TOKEN" - case yaml_DOCUMENT_END_TOKEN: - return "yaml_DOCUMENT_END_TOKEN" - case yaml_BLOCK_SEQUENCE_START_TOKEN: - return "yaml_BLOCK_SEQUENCE_START_TOKEN" - case yaml_BLOCK_MAPPING_START_TOKEN: - return "yaml_BLOCK_MAPPING_START_TOKEN" - case yaml_BLOCK_END_TOKEN: - return "yaml_BLOCK_END_TOKEN" - case yaml_FLOW_SEQUENCE_START_TOKEN: - return "yaml_FLOW_SEQUENCE_START_TOKEN" - case yaml_FLOW_SEQUENCE_END_TOKEN: - return "yaml_FLOW_SEQUENCE_END_TOKEN" - case yaml_FLOW_MAPPING_START_TOKEN: - return "yaml_FLOW_MAPPING_START_TOKEN" - case yaml_FLOW_MAPPING_END_TOKEN: - return "yaml_FLOW_MAPPING_END_TOKEN" - case yaml_BLOCK_ENTRY_TOKEN: - return "yaml_BLOCK_ENTRY_TOKEN" - case yaml_FLOW_ENTRY_TOKEN: - return "yaml_FLOW_ENTRY_TOKEN" - case yaml_KEY_TOKEN: - return "yaml_KEY_TOKEN" - case yaml_VALUE_TOKEN: - return "yaml_VALUE_TOKEN" - case yaml_ALIAS_TOKEN: - return "yaml_ALIAS_TOKEN" - case yaml_ANCHOR_TOKEN: - return "yaml_ANCHOR_TOKEN" - case yaml_TAG_TOKEN: - return "yaml_TAG_TOKEN" - case yaml_SCALAR_TOKEN: - return "yaml_SCALAR_TOKEN" - } - return "" -} - -// The token structure. -type yaml_token_t struct { - // The token type. - typ yaml_token_type_t - - // The start/end of the token. - start_mark, end_mark yaml_mark_t - - // The stream encoding (for yaml_STREAM_START_TOKEN). - encoding yaml_encoding_t - - // The alias/anchor/scalar value or tag/tag directive handle - // (for yaml_ALIAS_TOKEN, yaml_ANCHOR_TOKEN, yaml_SCALAR_TOKEN, yaml_TAG_TOKEN, yaml_TAG_DIRECTIVE_TOKEN). - value []byte - - // The tag suffix (for yaml_TAG_TOKEN). - suffix []byte - - // The tag directive prefix (for yaml_TAG_DIRECTIVE_TOKEN). - prefix []byte - - // The scalar style (for yaml_SCALAR_TOKEN). - style yaml_scalar_style_t - - // The version directive major/minor (for yaml_VERSION_DIRECTIVE_TOKEN). - major, minor int8 -} - -// Events - -type yaml_event_type_t int8 - -// Event types. -const ( - // An empty event. - yaml_NO_EVENT yaml_event_type_t = iota - - yaml_STREAM_START_EVENT // A STREAM-START event. - yaml_STREAM_END_EVENT // A STREAM-END event. - yaml_DOCUMENT_START_EVENT // A DOCUMENT-START event. - yaml_DOCUMENT_END_EVENT // A DOCUMENT-END event. - yaml_ALIAS_EVENT // An ALIAS event. - yaml_SCALAR_EVENT // A SCALAR event. - yaml_SEQUENCE_START_EVENT // A SEQUENCE-START event. - yaml_SEQUENCE_END_EVENT // A SEQUENCE-END event. - yaml_MAPPING_START_EVENT // A MAPPING-START event. - yaml_MAPPING_END_EVENT // A MAPPING-END event. -) - -var eventStrings = []string{ - yaml_NO_EVENT: "none", - yaml_STREAM_START_EVENT: "stream start", - yaml_STREAM_END_EVENT: "stream end", - yaml_DOCUMENT_START_EVENT: "document start", - yaml_DOCUMENT_END_EVENT: "document end", - yaml_ALIAS_EVENT: "alias", - yaml_SCALAR_EVENT: "scalar", - yaml_SEQUENCE_START_EVENT: "sequence start", - yaml_SEQUENCE_END_EVENT: "sequence end", - yaml_MAPPING_START_EVENT: "mapping start", - yaml_MAPPING_END_EVENT: "mapping end", -} - -func (e yaml_event_type_t) String() string { - if e < 0 || int(e) >= len(eventStrings) { - return fmt.Sprintf("unknown event %d", e) - } - return eventStrings[e] -} - -// The event structure. -type yaml_event_t struct { - - // The event type. - typ yaml_event_type_t - - // The start and end of the event. - start_mark, end_mark yaml_mark_t - - // The document encoding (for yaml_STREAM_START_EVENT). - encoding yaml_encoding_t - - // The version directive (for yaml_DOCUMENT_START_EVENT). - version_directive *yaml_version_directive_t - - // The list of tag directives (for yaml_DOCUMENT_START_EVENT). - tag_directives []yaml_tag_directive_t - - // The anchor (for yaml_SCALAR_EVENT, yaml_SEQUENCE_START_EVENT, yaml_MAPPING_START_EVENT, yaml_ALIAS_EVENT). - anchor []byte - - // The tag (for yaml_SCALAR_EVENT, yaml_SEQUENCE_START_EVENT, yaml_MAPPING_START_EVENT). - tag []byte - - // The scalar value (for yaml_SCALAR_EVENT). - value []byte - - // Is the document start/end indicator implicit, or the tag optional? - // (for yaml_DOCUMENT_START_EVENT, yaml_DOCUMENT_END_EVENT, yaml_SEQUENCE_START_EVENT, yaml_MAPPING_START_EVENT, yaml_SCALAR_EVENT). - implicit bool - - // Is the tag optional for any non-plain style? (for yaml_SCALAR_EVENT). - quoted_implicit bool - - // The style (for yaml_SCALAR_EVENT, yaml_SEQUENCE_START_EVENT, yaml_MAPPING_START_EVENT). - style yaml_style_t -} - -func (e *yaml_event_t) scalar_style() yaml_scalar_style_t { return yaml_scalar_style_t(e.style) } -func (e *yaml_event_t) sequence_style() yaml_sequence_style_t { return yaml_sequence_style_t(e.style) } -func (e *yaml_event_t) mapping_style() yaml_mapping_style_t { return yaml_mapping_style_t(e.style) } - -// Nodes - -const ( - yaml_NULL_TAG = "tag:yaml.org,2002:null" // The tag !!null with the only possible value: null. - yaml_BOOL_TAG = "tag:yaml.org,2002:bool" // The tag !!bool with the values: true and false. - yaml_STR_TAG = "tag:yaml.org,2002:str" // The tag !!str for string values. - yaml_INT_TAG = "tag:yaml.org,2002:int" // The tag !!int for integer values. - yaml_FLOAT_TAG = "tag:yaml.org,2002:float" // The tag !!float for float values. - yaml_TIMESTAMP_TAG = "tag:yaml.org,2002:timestamp" // The tag !!timestamp for date and time values. - - yaml_SEQ_TAG = "tag:yaml.org,2002:seq" // The tag !!seq is used to denote sequences. - yaml_MAP_TAG = "tag:yaml.org,2002:map" // The tag !!map is used to denote mapping. - - // Not in original libyaml. - yaml_BINARY_TAG = "tag:yaml.org,2002:binary" - yaml_MERGE_TAG = "tag:yaml.org,2002:merge" - - yaml_DEFAULT_SCALAR_TAG = yaml_STR_TAG // The default scalar tag is !!str. - yaml_DEFAULT_SEQUENCE_TAG = yaml_SEQ_TAG // The default sequence tag is !!seq. - yaml_DEFAULT_MAPPING_TAG = yaml_MAP_TAG // The default mapping tag is !!map. -) - -type yaml_node_type_t int - -// Node types. -const ( - // An empty node. - yaml_NO_NODE yaml_node_type_t = iota - - yaml_SCALAR_NODE // A scalar node. - yaml_SEQUENCE_NODE // A sequence node. - yaml_MAPPING_NODE // A mapping node. -) - -// An element of a sequence node. -type yaml_node_item_t int - -// An element of a mapping node. -type yaml_node_pair_t struct { - key int // The key of the element. - value int // The value of the element. -} - -// The node structure. -type yaml_node_t struct { - typ yaml_node_type_t // The node type. - tag []byte // The node tag. - - // The node data. - - // The scalar parameters (for yaml_SCALAR_NODE). - scalar struct { - value []byte // The scalar value. - length int // The length of the scalar value. - style yaml_scalar_style_t // The scalar style. - } - - // The sequence parameters (for YAML_SEQUENCE_NODE). - sequence struct { - items_data []yaml_node_item_t // The stack of sequence items. - style yaml_sequence_style_t // The sequence style. - } - - // The mapping parameters (for yaml_MAPPING_NODE). - mapping struct { - pairs_data []yaml_node_pair_t // The stack of mapping pairs (key, value). - pairs_start *yaml_node_pair_t // The beginning of the stack. - pairs_end *yaml_node_pair_t // The end of the stack. - pairs_top *yaml_node_pair_t // The top of the stack. - style yaml_mapping_style_t // The mapping style. - } - - start_mark yaml_mark_t // The beginning of the node. - end_mark yaml_mark_t // The end of the node. - -} - -// The document structure. -type yaml_document_t struct { - - // The document nodes. - nodes []yaml_node_t - - // The version directive. - version_directive *yaml_version_directive_t - - // The list of tag directives. - tag_directives_data []yaml_tag_directive_t - tag_directives_start int // The beginning of the tag directives list. - tag_directives_end int // The end of the tag directives list. - - start_implicit int // Is the document start indicator implicit? - end_implicit int // Is the document end indicator implicit? - - // The start/end of the document. - start_mark, end_mark yaml_mark_t -} - -// The prototype of a read handler. -// -// The read handler is called when the parser needs to read more bytes from the -// source. The handler should write not more than size bytes to the buffer. -// The number of written bytes should be set to the size_read variable. -// -// [in,out] data A pointer to an application data specified by -// yaml_parser_set_input(). -// [out] buffer The buffer to write the data from the source. -// [in] size The size of the buffer. -// [out] size_read The actual number of bytes read from the source. -// -// On success, the handler should return 1. If the handler failed, -// the returned value should be 0. On EOF, the handler should set the -// size_read to 0 and return 1. -type yaml_read_handler_t func(parser *yaml_parser_t, buffer []byte) (n int, err error) - -// This structure holds information about a potential simple key. -type yaml_simple_key_t struct { - possible bool // Is a simple key possible? - required bool // Is a simple key required? - token_number int // The number of the token. - mark yaml_mark_t // The position mark. -} - -// The states of the parser. -type yaml_parser_state_t int - -const ( - yaml_PARSE_STREAM_START_STATE yaml_parser_state_t = iota - - yaml_PARSE_IMPLICIT_DOCUMENT_START_STATE // Expect the beginning of an implicit document. - yaml_PARSE_DOCUMENT_START_STATE // Expect DOCUMENT-START. - yaml_PARSE_DOCUMENT_CONTENT_STATE // Expect the content of a document. - yaml_PARSE_DOCUMENT_END_STATE // Expect DOCUMENT-END. - yaml_PARSE_BLOCK_NODE_STATE // Expect a block node. - yaml_PARSE_BLOCK_NODE_OR_INDENTLESS_SEQUENCE_STATE // Expect a block node or indentless sequence. - yaml_PARSE_FLOW_NODE_STATE // Expect a flow node. - yaml_PARSE_BLOCK_SEQUENCE_FIRST_ENTRY_STATE // Expect the first entry of a block sequence. - yaml_PARSE_BLOCK_SEQUENCE_ENTRY_STATE // Expect an entry of a block sequence. - yaml_PARSE_INDENTLESS_SEQUENCE_ENTRY_STATE // Expect an entry of an indentless sequence. - yaml_PARSE_BLOCK_MAPPING_FIRST_KEY_STATE // Expect the first key of a block mapping. - yaml_PARSE_BLOCK_MAPPING_KEY_STATE // Expect a block mapping key. - yaml_PARSE_BLOCK_MAPPING_VALUE_STATE // Expect a block mapping value. - yaml_PARSE_FLOW_SEQUENCE_FIRST_ENTRY_STATE // Expect the first entry of a flow sequence. - yaml_PARSE_FLOW_SEQUENCE_ENTRY_STATE // Expect an entry of a flow sequence. - yaml_PARSE_FLOW_SEQUENCE_ENTRY_MAPPING_KEY_STATE // Expect a key of an ordered mapping. - yaml_PARSE_FLOW_SEQUENCE_ENTRY_MAPPING_VALUE_STATE // Expect a value of an ordered mapping. - yaml_PARSE_FLOW_SEQUENCE_ENTRY_MAPPING_END_STATE // Expect the and of an ordered mapping entry. - yaml_PARSE_FLOW_MAPPING_FIRST_KEY_STATE // Expect the first key of a flow mapping. - yaml_PARSE_FLOW_MAPPING_KEY_STATE // Expect a key of a flow mapping. - yaml_PARSE_FLOW_MAPPING_VALUE_STATE // Expect a value of a flow mapping. - yaml_PARSE_FLOW_MAPPING_EMPTY_VALUE_STATE // Expect an empty value of a flow mapping. - yaml_PARSE_END_STATE // Expect nothing. -) - -func (ps yaml_parser_state_t) String() string { - switch ps { - case yaml_PARSE_STREAM_START_STATE: - return "yaml_PARSE_STREAM_START_STATE" - case yaml_PARSE_IMPLICIT_DOCUMENT_START_STATE: - return "yaml_PARSE_IMPLICIT_DOCUMENT_START_STATE" - case yaml_PARSE_DOCUMENT_START_STATE: - return "yaml_PARSE_DOCUMENT_START_STATE" - case yaml_PARSE_DOCUMENT_CONTENT_STATE: - return "yaml_PARSE_DOCUMENT_CONTENT_STATE" - case yaml_PARSE_DOCUMENT_END_STATE: - return "yaml_PARSE_DOCUMENT_END_STATE" - case yaml_PARSE_BLOCK_NODE_STATE: - return "yaml_PARSE_BLOCK_NODE_STATE" - case yaml_PARSE_BLOCK_NODE_OR_INDENTLESS_SEQUENCE_STATE: - return "yaml_PARSE_BLOCK_NODE_OR_INDENTLESS_SEQUENCE_STATE" - case yaml_PARSE_FLOW_NODE_STATE: - return "yaml_PARSE_FLOW_NODE_STATE" - case yaml_PARSE_BLOCK_SEQUENCE_FIRST_ENTRY_STATE: - return "yaml_PARSE_BLOCK_SEQUENCE_FIRST_ENTRY_STATE" - case yaml_PARSE_BLOCK_SEQUENCE_ENTRY_STATE: - return "yaml_PARSE_BLOCK_SEQUENCE_ENTRY_STATE" - case yaml_PARSE_INDENTLESS_SEQUENCE_ENTRY_STATE: - return "yaml_PARSE_INDENTLESS_SEQUENCE_ENTRY_STATE" - case yaml_PARSE_BLOCK_MAPPING_FIRST_KEY_STATE: - return "yaml_PARSE_BLOCK_MAPPING_FIRST_KEY_STATE" - case yaml_PARSE_BLOCK_MAPPING_KEY_STATE: - return "yaml_PARSE_BLOCK_MAPPING_KEY_STATE" - case yaml_PARSE_BLOCK_MAPPING_VALUE_STATE: - return "yaml_PARSE_BLOCK_MAPPING_VALUE_STATE" - case yaml_PARSE_FLOW_SEQUENCE_FIRST_ENTRY_STATE: - return "yaml_PARSE_FLOW_SEQUENCE_FIRST_ENTRY_STATE" - case yaml_PARSE_FLOW_SEQUENCE_ENTRY_STATE: - return "yaml_PARSE_FLOW_SEQUENCE_ENTRY_STATE" - case yaml_PARSE_FLOW_SEQUENCE_ENTRY_MAPPING_KEY_STATE: - return "yaml_PARSE_FLOW_SEQUENCE_ENTRY_MAPPING_KEY_STATE" - case yaml_PARSE_FLOW_SEQUENCE_ENTRY_MAPPING_VALUE_STATE: - return "yaml_PARSE_FLOW_SEQUENCE_ENTRY_MAPPING_VALUE_STATE" - case yaml_PARSE_FLOW_SEQUENCE_ENTRY_MAPPING_END_STATE: - return "yaml_PARSE_FLOW_SEQUENCE_ENTRY_MAPPING_END_STATE" - case yaml_PARSE_FLOW_MAPPING_FIRST_KEY_STATE: - return "yaml_PARSE_FLOW_MAPPING_FIRST_KEY_STATE" - case yaml_PARSE_FLOW_MAPPING_KEY_STATE: - return "yaml_PARSE_FLOW_MAPPING_KEY_STATE" - case yaml_PARSE_FLOW_MAPPING_VALUE_STATE: - return "yaml_PARSE_FLOW_MAPPING_VALUE_STATE" - case yaml_PARSE_FLOW_MAPPING_EMPTY_VALUE_STATE: - return "yaml_PARSE_FLOW_MAPPING_EMPTY_VALUE_STATE" - case yaml_PARSE_END_STATE: - return "yaml_PARSE_END_STATE" - } - return "" -} - -// This structure holds aliases data. -type yaml_alias_data_t struct { - anchor []byte // The anchor. - index int // The node id. - mark yaml_mark_t // The anchor mark. -} - -// The parser structure. -// -// All members are internal. Manage the structure using the -// yaml_parser_ family of functions. -type yaml_parser_t struct { - - // Error handling - - error yaml_error_type_t // Error type. - - problem string // Error description. - - // The byte about which the problem occurred. - problem_offset int - problem_value int - problem_mark yaml_mark_t - - // The error context. - context string - context_mark yaml_mark_t - - // Reader stuff - - read_handler yaml_read_handler_t // Read handler. - - input_reader io.Reader // File input data. - input []byte // String input data. - input_pos int - - eof bool // EOF flag - - buffer []byte // The working buffer. - buffer_pos int // The current position of the buffer. - - unread int // The number of unread characters in the buffer. - - raw_buffer []byte // The raw buffer. - raw_buffer_pos int // The current position of the buffer. - - encoding yaml_encoding_t // The input encoding. - - offset int // The offset of the current position (in bytes). - mark yaml_mark_t // The mark of the current position. - - // Scanner stuff - - stream_start_produced bool // Have we started to scan the input stream? - stream_end_produced bool // Have we reached the end of the input stream? - - flow_level int // The number of unclosed '[' and '{' indicators. - - tokens []yaml_token_t // The tokens queue. - tokens_head int // The head of the tokens queue. - tokens_parsed int // The number of tokens fetched from the queue. - token_available bool // Does the tokens queue contain a token ready for dequeueing. - - indent int // The current indentation level. - indents []int // The indentation levels stack. - - simple_key_allowed bool // May a simple key occur at the current position? - simple_keys []yaml_simple_key_t // The stack of simple keys. - - // Parser stuff - - state yaml_parser_state_t // The current parser state. - states []yaml_parser_state_t // The parser states stack. - marks []yaml_mark_t // The stack of marks. - tag_directives []yaml_tag_directive_t // The list of TAG directives. - - // Dumper stuff - - aliases []yaml_alias_data_t // The alias data. - - document *yaml_document_t // The currently parsed document. -} - -// Emitter Definitions - -// The prototype of a write handler. -// -// The write handler is called when the emitter needs to flush the accumulated -// characters to the output. The handler should write @a size bytes of the -// @a buffer to the output. -// -// @param[in,out] data A pointer to an application data specified by -// yaml_emitter_set_output(). -// @param[in] buffer The buffer with bytes to be written. -// @param[in] size The size of the buffer. -// -// @returns On success, the handler should return @c 1. If the handler failed, -// the returned value should be @c 0. -// -type yaml_write_handler_t func(emitter *yaml_emitter_t, buffer []byte) error - -type yaml_emitter_state_t int - -// The emitter states. -const ( - // Expect STREAM-START. - yaml_EMIT_STREAM_START_STATE yaml_emitter_state_t = iota - - yaml_EMIT_FIRST_DOCUMENT_START_STATE // Expect the first DOCUMENT-START or STREAM-END. - yaml_EMIT_DOCUMENT_START_STATE // Expect DOCUMENT-START or STREAM-END. - yaml_EMIT_DOCUMENT_CONTENT_STATE // Expect the content of a document. - yaml_EMIT_DOCUMENT_END_STATE // Expect DOCUMENT-END. - yaml_EMIT_FLOW_SEQUENCE_FIRST_ITEM_STATE // Expect the first item of a flow sequence. - yaml_EMIT_FLOW_SEQUENCE_ITEM_STATE // Expect an item of a flow sequence. - yaml_EMIT_FLOW_MAPPING_FIRST_KEY_STATE // Expect the first key of a flow mapping. - yaml_EMIT_FLOW_MAPPING_KEY_STATE // Expect a key of a flow mapping. - yaml_EMIT_FLOW_MAPPING_SIMPLE_VALUE_STATE // Expect a value for a simple key of a flow mapping. - yaml_EMIT_FLOW_MAPPING_VALUE_STATE // Expect a value of a flow mapping. - yaml_EMIT_BLOCK_SEQUENCE_FIRST_ITEM_STATE // Expect the first item of a block sequence. - yaml_EMIT_BLOCK_SEQUENCE_ITEM_STATE // Expect an item of a block sequence. - yaml_EMIT_BLOCK_MAPPING_FIRST_KEY_STATE // Expect the first key of a block mapping. - yaml_EMIT_BLOCK_MAPPING_KEY_STATE // Expect the key of a block mapping. - yaml_EMIT_BLOCK_MAPPING_SIMPLE_VALUE_STATE // Expect a value for a simple key of a block mapping. - yaml_EMIT_BLOCK_MAPPING_VALUE_STATE // Expect a value of a block mapping. - yaml_EMIT_END_STATE // Expect nothing. -) - -// The emitter structure. -// -// All members are internal. Manage the structure using the @c yaml_emitter_ -// family of functions. -type yaml_emitter_t struct { - - // Error handling - - error yaml_error_type_t // Error type. - problem string // Error description. - - // Writer stuff - - write_handler yaml_write_handler_t // Write handler. - - output_buffer *[]byte // String output data. - output_writer io.Writer // File output data. - - buffer []byte // The working buffer. - buffer_pos int // The current position of the buffer. - - raw_buffer []byte // The raw buffer. - raw_buffer_pos int // The current position of the buffer. - - encoding yaml_encoding_t // The stream encoding. - - // Emitter stuff - - canonical bool // If the output is in the canonical style? - best_indent int // The number of indentation spaces. - best_width int // The preferred width of the output lines. - unicode bool // Allow unescaped non-ASCII characters? - line_break yaml_break_t // The preferred line break. - - state yaml_emitter_state_t // The current emitter state. - states []yaml_emitter_state_t // The stack of states. - - events []yaml_event_t // The event queue. - events_head int // The head of the event queue. - - indents []int // The stack of indentation levels. - - tag_directives []yaml_tag_directive_t // The list of tag directives. - - indent int // The current indentation level. - - flow_level int // The current flow level. - - root_context bool // Is it the document root context? - sequence_context bool // Is it a sequence context? - mapping_context bool // Is it a mapping context? - simple_key_context bool // Is it a simple mapping key context? - - line int // The current line. - column int // The current column. - whitespace bool // If the last character was a whitespace? - indention bool // If the last character was an indentation character (' ', '-', '?', ':')? - open_ended bool // If an explicit document end is required? - - // Anchor analysis. - anchor_data struct { - anchor []byte // The anchor value. - alias bool // Is it an alias? - } - - // Tag analysis. - tag_data struct { - handle []byte // The tag handle. - suffix []byte // The tag suffix. - } - - // Scalar analysis. - scalar_data struct { - value []byte // The scalar value. - multiline bool // Does the scalar contain line breaks? - flow_plain_allowed bool // Can the scalar be expessed in the flow plain style? - block_plain_allowed bool // Can the scalar be expressed in the block plain style? - single_quoted_allowed bool // Can the scalar be expressed in the single quoted style? - block_allowed bool // Can the scalar be expressed in the literal or folded styles? - style yaml_scalar_style_t // The output style. - } - - // Dumper stuff - - opened bool // If the stream was already opened? - closed bool // If the stream was already closed? - - // The information associated with the document nodes. - anchors *struct { - references int // The number of references. - anchor int // The anchor id. - serialized bool // If the node has been emitted? - } - - last_anchor_id int // The last assigned anchor id. - - document *yaml_document_t // The currently emitted document. -} diff --git a/vendor/gopkg.in/yaml.v2/yamlprivateh.go b/vendor/gopkg.in/yaml.v2/yamlprivateh.go deleted file mode 100644 index 8110ce3c37..0000000000 --- a/vendor/gopkg.in/yaml.v2/yamlprivateh.go +++ /dev/null @@ -1,173 +0,0 @@ -package yaml - -const ( - // The size of the input raw buffer. - input_raw_buffer_size = 512 - - // The size of the input buffer. - // It should be possible to decode the whole raw buffer. - input_buffer_size = input_raw_buffer_size * 3 - - // The size of the output buffer. - output_buffer_size = 128 - - // The size of the output raw buffer. - // It should be possible to encode the whole output buffer. - output_raw_buffer_size = (output_buffer_size*2 + 2) - - // The size of other stacks and queues. - initial_stack_size = 16 - initial_queue_size = 16 - initial_string_size = 16 -) - -// Check if the character at the specified position is an alphabetical -// character, a digit, '_', or '-'. -func is_alpha(b []byte, i int) bool { - return b[i] >= '0' && b[i] <= '9' || b[i] >= 'A' && b[i] <= 'Z' || b[i] >= 'a' && b[i] <= 'z' || b[i] == '_' || b[i] == '-' -} - -// Check if the character at the specified position is a digit. -func is_digit(b []byte, i int) bool { - return b[i] >= '0' && b[i] <= '9' -} - -// Get the value of a digit. -func as_digit(b []byte, i int) int { - return int(b[i]) - '0' -} - -// Check if the character at the specified position is a hex-digit. -func is_hex(b []byte, i int) bool { - return b[i] >= '0' && b[i] <= '9' || b[i] >= 'A' && b[i] <= 'F' || b[i] >= 'a' && b[i] <= 'f' -} - -// Get the value of a hex-digit. -func as_hex(b []byte, i int) int { - bi := b[i] - if bi >= 'A' && bi <= 'F' { - return int(bi) - 'A' + 10 - } - if bi >= 'a' && bi <= 'f' { - return int(bi) - 'a' + 10 - } - return int(bi) - '0' -} - -// Check if the character is ASCII. -func is_ascii(b []byte, i int) bool { - return b[i] <= 0x7F -} - -// Check if the character at the start of the buffer can be printed unescaped. -func is_printable(b []byte, i int) bool { - return ((b[i] == 0x0A) || // . == #x0A - (b[i] >= 0x20 && b[i] <= 0x7E) || // #x20 <= . <= #x7E - (b[i] == 0xC2 && b[i+1] >= 0xA0) || // #0xA0 <= . <= #xD7FF - (b[i] > 0xC2 && b[i] < 0xED) || - (b[i] == 0xED && b[i+1] < 0xA0) || - (b[i] == 0xEE) || - (b[i] == 0xEF && // #xE000 <= . <= #xFFFD - !(b[i+1] == 0xBB && b[i+2] == 0xBF) && // && . != #xFEFF - !(b[i+1] == 0xBF && (b[i+2] == 0xBE || b[i+2] == 0xBF)))) -} - -// Check if the character at the specified position is NUL. -func is_z(b []byte, i int) bool { - return b[i] == 0x00 -} - -// Check if the beginning of the buffer is a BOM. -func is_bom(b []byte, i int) bool { - return b[0] == 0xEF && b[1] == 0xBB && b[2] == 0xBF -} - -// Check if the character at the specified position is space. -func is_space(b []byte, i int) bool { - return b[i] == ' ' -} - -// Check if the character at the specified position is tab. -func is_tab(b []byte, i int) bool { - return b[i] == '\t' -} - -// Check if the character at the specified position is blank (space or tab). -func is_blank(b []byte, i int) bool { - //return is_space(b, i) || is_tab(b, i) - return b[i] == ' ' || b[i] == '\t' -} - -// Check if the character at the specified position is a line break. -func is_break(b []byte, i int) bool { - return (b[i] == '\r' || // CR (#xD) - b[i] == '\n' || // LF (#xA) - b[i] == 0xC2 && b[i+1] == 0x85 || // NEL (#x85) - b[i] == 0xE2 && b[i+1] == 0x80 && b[i+2] == 0xA8 || // LS (#x2028) - b[i] == 0xE2 && b[i+1] == 0x80 && b[i+2] == 0xA9) // PS (#x2029) -} - -func is_crlf(b []byte, i int) bool { - return b[i] == '\r' && b[i+1] == '\n' -} - -// Check if the character is a line break or NUL. -func is_breakz(b []byte, i int) bool { - //return is_break(b, i) || is_z(b, i) - return ( // is_break: - b[i] == '\r' || // CR (#xD) - b[i] == '\n' || // LF (#xA) - b[i] == 0xC2 && b[i+1] == 0x85 || // NEL (#x85) - b[i] == 0xE2 && b[i+1] == 0x80 && b[i+2] == 0xA8 || // LS (#x2028) - b[i] == 0xE2 && b[i+1] == 0x80 && b[i+2] == 0xA9 || // PS (#x2029) - // is_z: - b[i] == 0) -} - -// Check if the character is a line break, space, or NUL. -func is_spacez(b []byte, i int) bool { - //return is_space(b, i) || is_breakz(b, i) - return ( // is_space: - b[i] == ' ' || - // is_breakz: - b[i] == '\r' || // CR (#xD) - b[i] == '\n' || // LF (#xA) - b[i] == 0xC2 && b[i+1] == 0x85 || // NEL (#x85) - b[i] == 0xE2 && b[i+1] == 0x80 && b[i+2] == 0xA8 || // LS (#x2028) - b[i] == 0xE2 && b[i+1] == 0x80 && b[i+2] == 0xA9 || // PS (#x2029) - b[i] == 0) -} - -// Check if the character is a line break, space, tab, or NUL. -func is_blankz(b []byte, i int) bool { - //return is_blank(b, i) || is_breakz(b, i) - return ( // is_blank: - b[i] == ' ' || b[i] == '\t' || - // is_breakz: - b[i] == '\r' || // CR (#xD) - b[i] == '\n' || // LF (#xA) - b[i] == 0xC2 && b[i+1] == 0x85 || // NEL (#x85) - b[i] == 0xE2 && b[i+1] == 0x80 && b[i+2] == 0xA8 || // LS (#x2028) - b[i] == 0xE2 && b[i+1] == 0x80 && b[i+2] == 0xA9 || // PS (#x2029) - b[i] == 0) -} - -// Determine the width of the character. -func width(b byte) int { - // Don't replace these by a switch without first - // confirming that it is being inlined. - if b&0x80 == 0x00 { - return 1 - } - if b&0xE0 == 0xC0 { - return 2 - } - if b&0xF0 == 0xE0 { - return 3 - } - if b&0xF8 == 0xF0 { - return 4 - } - return 0 - -} diff --git a/vendor/vendor.json b/vendor/vendor.json deleted file mode 100644 index fa90039278..0000000000 --- a/vendor/vendor.json +++ /dev/null @@ -1,31 +0,0 @@ -{ - "comment": "", - "ignore": "test", - "package": [ - { - "checksumSHA1": "J6lNRPdrYhKft6S8x33K9brxyhE=", - "path": "golang.org/x/crypto/acme", - "revision": "c126467f60eb25f8f27e5a981f32a87e3965053f", - "revisionTime": "2018-07-23T15:26:11Z" - }, - { - "checksumSHA1": "EFjIi/zCZ1Cte0MQtyxGCTgSzk8=", - "path": "golang.org/x/crypto/acme/autocert", - "revision": "c126467f60eb25f8f27e5a981f32a87e3965053f", - "revisionTime": "2018-07-23T15:26:11Z" - }, - { - "checksumSHA1": "1MGpGDQqnUoRpv7VEcQrXOBydXE=", - "path": "golang.org/x/crypto/pbkdf2", - "revision": "c126467f60eb25f8f27e5a981f32a87e3965053f", - "revisionTime": "2018-07-23T15:26:11Z" - }, - { - "checksumSHA1": "ZSWoOPUNRr5+3dhkLK3C4cZAQPk=", - "path": "gopkg.in/yaml.v2", - "revision": "5420a8b6744d3b0345ab293f6fcba19c978f1183", - "revisionTime": "2018-03-28T19:50:20Z" - } - ], - "rootPath": "github.com/astaxie/beego" -}