Switch back go upstream bwmarrin/discordgo

Commit ffa9956c9b got merged in.
This commit is contained in:
Wim 2018-11-13 00:02:07 +01:00
parent e9419f10d3
commit f8dc24bc09
78 changed files with 4948 additions and 1252 deletions

View File

@ -11,7 +11,7 @@ import (
"github.com/42wim/matterbridge/bridge" "github.com/42wim/matterbridge/bridge"
"github.com/42wim/matterbridge/bridge/config" "github.com/42wim/matterbridge/bridge/config"
"github.com/42wim/matterbridge/bridge/helper" "github.com/42wim/matterbridge/bridge/helper"
"github.com/matterbridge/discordgo" "github.com/bwmarrin/discordgo"
) )
const MessageLength = 1950 const MessageLength = 1950

6
go.mod
View File

@ -6,7 +6,7 @@ require (
github.com/Philipp15b/go-steam v0.0.0-20161020161927-e0f3bb9566e3 github.com/Philipp15b/go-steam v0.0.0-20161020161927-e0f3bb9566e3
github.com/Sirupsen/logrus v1.0.6 // indirect github.com/Sirupsen/logrus v1.0.6 // indirect
github.com/alecthomas/log4go v0.0.0-20160307011253-e5dc62318d9b // indirect github.com/alecthomas/log4go v0.0.0-20160307011253-e5dc62318d9b // indirect
github.com/bwmarrin/discordgo v0.0.0-20180201002541-8d5ab59c63e5 // indirect github.com/bwmarrin/discordgo v0.19.0
github.com/davecgh/go-spew v1.1.0 // indirect github.com/davecgh/go-spew v1.1.0 // indirect
github.com/dfordsoft/golib v0.0.0-20180313113957-2ea3495aee1d github.com/dfordsoft/golib v0.0.0-20180313113957-2ea3495aee1d
github.com/dgrijalva/jwt-go v0.0.0-20170508165458-6c8dedd55f8a // indirect github.com/dgrijalva/jwt-go v0.0.0-20170508165458-6c8dedd55f8a // indirect
@ -16,7 +16,7 @@ require (
github.com/google/gops v0.0.0-20170319002943-62f833fc9f6c github.com/google/gops v0.0.0-20170319002943-62f833fc9f6c
github.com/gopherjs/gopherjs v0.0.0-20180628210949-0892b62f0d9f // indirect github.com/gopherjs/gopherjs v0.0.0-20180628210949-0892b62f0d9f // indirect
github.com/gorilla/schema v0.0.0-20170317173100-f3c80893412c github.com/gorilla/schema v0.0.0-20170317173100-f3c80893412c
github.com/gorilla/websocket v0.0.0-20170319172727-a91eba7f9777 github.com/gorilla/websocket v1.4.0
github.com/hashicorp/golang-lru v0.0.0-20160813221303-0a025b7e63ad github.com/hashicorp/golang-lru v0.0.0-20160813221303-0a025b7e63ad
github.com/hashicorp/hcl v0.0.0-20171017181929-23c074d0eceb // indirect github.com/hashicorp/hcl v0.0.0-20171017181929-23c074d0eceb // indirect
github.com/hpcloud/tail v1.0.0 // indirect github.com/hpcloud/tail v1.0.0 // indirect
@ -30,7 +30,6 @@ require (
github.com/lusis/go-slackbot v0.0.0-20180109053408-401027ccfef5 // indirect github.com/lusis/go-slackbot v0.0.0-20180109053408-401027ccfef5 // indirect
github.com/lusis/slack-test v0.0.0-20180109053238-3c758769bfa6 // indirect github.com/lusis/slack-test v0.0.0-20180109053238-3c758769bfa6 // indirect
github.com/magiconair/properties v0.0.0-20180217134545-2c9e95027885 // indirect github.com/magiconair/properties v0.0.0-20180217134545-2c9e95027885 // indirect
github.com/matterbridge/discordgo v0.0.0-20180806170629-ef40ff5ba64f
github.com/matterbridge/go-xmpp v0.0.0-20180529212104-cd19799fba91 github.com/matterbridge/go-xmpp v0.0.0-20180529212104-cd19799fba91
github.com/matterbridge/gomatrix v0.0.0-20171224233421-78ac6a1a0f5f github.com/matterbridge/gomatrix v0.0.0-20171224233421-78ac6a1a0f5f
github.com/matterbridge/gozulipbot v0.0.0-20180507190239-b6bb12d33544 github.com/matterbridge/gozulipbot v0.0.0-20180507190239-b6bb12d33544
@ -72,7 +71,6 @@ require (
github.com/valyala/fasttemplate v0.0.0-20170224212429-dcecefd839c4 // indirect github.com/valyala/fasttemplate v0.0.0-20170224212429-dcecefd839c4 // indirect
github.com/x-cray/logrus-prefixed-formatter v0.5.2 // indirect github.com/x-cray/logrus-prefixed-formatter v0.5.2 // indirect
github.com/zfjagann/golang-ring v0.0.0-20141111230621-17637388c9f6 github.com/zfjagann/golang-ring v0.0.0-20141111230621-17637388c9f6
golang.org/x/crypto v0.0.0-20180228161326-91a49db82a88 // indirect
golang.org/x/net v0.0.0-20180108090419-434ec0c7fe37 // indirect golang.org/x/net v0.0.0-20180108090419-434ec0c7fe37 // indirect
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f // indirect golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f // indirect
golang.org/x/sys v0.0.0-20171130163741-8b4580aae2a0 // indirect golang.org/x/sys v0.0.0-20171130163741-8b4580aae2a0 // indirect

34
go.sum
View File

@ -4,12 +4,11 @@ github.com/BurntSushi/toml v0.0.0-20170318202913-d94612f9fc14 h1:v/zr4ns/4sSahF9
github.com/BurntSushi/toml v0.0.0-20170318202913-d94612f9fc14/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/toml v0.0.0-20170318202913-d94612f9fc14/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
github.com/Philipp15b/go-steam v0.0.0-20161020161927-e0f3bb9566e3 h1:V4+1E1SRYUySqwOoI3ZphFADtabbF568zTHa5ix/zU0= github.com/Philipp15b/go-steam v0.0.0-20161020161927-e0f3bb9566e3 h1:V4+1E1SRYUySqwOoI3ZphFADtabbF568zTHa5ix/zU0=
github.com/Philipp15b/go-steam v0.0.0-20161020161927-e0f3bb9566e3/go.mod h1:HuVM+sZFzumUdKPWiz+IlCMb4RdsKdT3T+nQBKL+sYg= github.com/Philipp15b/go-steam v0.0.0-20161020161927-e0f3bb9566e3/go.mod h1:HuVM+sZFzumUdKPWiz+IlCMb4RdsKdT3T+nQBKL+sYg=
github.com/Sirupsen/logrus v1.0.6 h1:HCAGQRk48dRVPA5Y+Yh0qdCSTzPOyU1tBJ7Q9YzotII=
github.com/Sirupsen/logrus v1.0.6/go.mod h1:rmk17hk6i8ZSAJkSDa7nOxamrG+SP4P0mm+DAvExv4U= github.com/Sirupsen/logrus v1.0.6/go.mod h1:rmk17hk6i8ZSAJkSDa7nOxamrG+SP4P0mm+DAvExv4U=
github.com/alecthomas/log4go v0.0.0-20160307011253-e5dc62318d9b h1:1OpGXps6UOY5HtQaQcLowsV1qMWCNBzhFvK7q4fgXtc= github.com/alecthomas/log4go v0.0.0-20160307011253-e5dc62318d9b h1:1OpGXps6UOY5HtQaQcLowsV1qMWCNBzhFvK7q4fgXtc=
github.com/alecthomas/log4go v0.0.0-20160307011253-e5dc62318d9b/go.mod h1:iCVmQ9g4TfaRX5m5jq5sXY7RXYWPv9/PynM/GocbG3w= github.com/alecthomas/log4go v0.0.0-20160307011253-e5dc62318d9b/go.mod h1:iCVmQ9g4TfaRX5m5jq5sXY7RXYWPv9/PynM/GocbG3w=
github.com/bwmarrin/discordgo v0.0.0-20180201002541-8d5ab59c63e5 h1:M7u44DKGpA5goDIBf0zRMYhT1Sp2Rd7hiTzXfeuw1UY= github.com/bwmarrin/discordgo v0.19.0 h1:kMED/DB0NR1QhRcalb85w0Cu3Ep2OrGAqZH1R5awQiY=
github.com/bwmarrin/discordgo v0.0.0-20180201002541-8d5ab59c63e5/go.mod h1:5NIvFv5Z7HddYuXbuQegZ684DleQaCFqChP2iuBivJ8= github.com/bwmarrin/discordgo v0.19.0/go.mod h1:O9S4p+ofTFwB02em7jkpkV8M3R0/PUVOwN61zSZ0r4Q=
github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dfordsoft/golib v0.0.0-20180313113957-2ea3495aee1d h1:rONNnZDE5CYuaSFQk+gP4GEQTXEUcyQ5p6p/dgxIHas= github.com/dfordsoft/golib v0.0.0-20180313113957-2ea3495aee1d h1:rONNnZDE5CYuaSFQk+gP4GEQTXEUcyQ5p6p/dgxIHas=
@ -24,28 +23,23 @@ github.com/golang/protobuf v0.0.0-20170613224224-e325f446bebc h1:wdhDSKrkYy24mcf
github.com/golang/protobuf v0.0.0-20170613224224-e325f446bebc/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v0.0.0-20170613224224-e325f446bebc/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/google/gops v0.0.0-20170319002943-62f833fc9f6c h1:MrMA1vhRTNidtgENqmsmLOIUS6ixMBOU/g10rm7IUe8= github.com/google/gops v0.0.0-20170319002943-62f833fc9f6c h1:MrMA1vhRTNidtgENqmsmLOIUS6ixMBOU/g10rm7IUe8=
github.com/google/gops v0.0.0-20170319002943-62f833fc9f6c/go.mod h1:pMQgrscwEK/aUSW1IFSaBPbJX82FPHWaSoJw1axQfD0= github.com/google/gops v0.0.0-20170319002943-62f833fc9f6c/go.mod h1:pMQgrscwEK/aUSW1IFSaBPbJX82FPHWaSoJw1axQfD0=
github.com/gopherjs/gopherjs v0.0.0-20180628210949-0892b62f0d9f h1:FDM3EtwZLyhW48YRiyqjivNlNZjAObv4xt4NnJaU+NQ=
github.com/gopherjs/gopherjs v0.0.0-20180628210949-0892b62f0d9f/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= github.com/gopherjs/gopherjs v0.0.0-20180628210949-0892b62f0d9f/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY=
github.com/gorilla/schema v0.0.0-20170317173100-f3c80893412c h1:mORYpib1aLu3M2Oi50Z1pNTXuDJEHcoLb6oo6VdOutk= github.com/gorilla/schema v0.0.0-20170317173100-f3c80893412c h1:mORYpib1aLu3M2Oi50Z1pNTXuDJEHcoLb6oo6VdOutk=
github.com/gorilla/schema v0.0.0-20170317173100-f3c80893412c/go.mod h1:kgLaKoK1FELgZqMAVxx/5cbj0kT+57qxUrAlIO2eleU= github.com/gorilla/schema v0.0.0-20170317173100-f3c80893412c/go.mod h1:kgLaKoK1FELgZqMAVxx/5cbj0kT+57qxUrAlIO2eleU=
github.com/gorilla/websocket v0.0.0-20170319172727-a91eba7f9777 h1:JIM+OacoOJRU30xpjMf8sulYqjr0ViA3WDrTX6j/yDI= github.com/gorilla/websocket v1.4.0 h1:WDFjx/TMzVgy9VdMMQi2K2Emtwi2QcUQsztZ/zLaH/Q=
github.com/gorilla/websocket v0.0.0-20170319172727-a91eba7f9777/go.mod h1:E7qHFY5m1UJ88s3WnNqhKjPHQ0heANvMoAMk2YaljkQ= github.com/gorilla/websocket v1.4.0/go.mod h1:E7qHFY5m1UJ88s3WnNqhKjPHQ0heANvMoAMk2YaljkQ=
github.com/hashicorp/golang-lru v0.0.0-20160813221303-0a025b7e63ad h1:eMxs9EL0PvIGS9TTtxg4R+JxuPGav82J8rA+GFnY7po= github.com/hashicorp/golang-lru v0.0.0-20160813221303-0a025b7e63ad h1:eMxs9EL0PvIGS9TTtxg4R+JxuPGav82J8rA+GFnY7po=
github.com/hashicorp/golang-lru v0.0.0-20160813221303-0a025b7e63ad/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= github.com/hashicorp/golang-lru v0.0.0-20160813221303-0a025b7e63ad/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8=
github.com/hashicorp/hcl v0.0.0-20171017181929-23c074d0eceb h1:1OvvPvZkn/yCQ3xBcM8y4020wdkMXPHLB4+NfoGWh4U= github.com/hashicorp/hcl v0.0.0-20171017181929-23c074d0eceb h1:1OvvPvZkn/yCQ3xBcM8y4020wdkMXPHLB4+NfoGWh4U=
github.com/hashicorp/hcl v0.0.0-20171017181929-23c074d0eceb/go.mod h1:oZtUIOe8dh44I2q6ScRibXws4Ajl+d+nod3AaR9vL5w= github.com/hashicorp/hcl v0.0.0-20171017181929-23c074d0eceb/go.mod h1:oZtUIOe8dh44I2q6ScRibXws4Ajl+d+nod3AaR9vL5w=
github.com/hpcloud/tail v1.0.0 h1:nfCOvKYfkgYP8hkirhJocXT2+zOD8yUNjXaWfTlyFKI=
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
github.com/jpillora/backoff v0.0.0-20170222002228-06c7a16c845d h1:ETeT81zgLgSNc4BWdDO2Fg9ekVItYErbNtE8mKD2pJA= github.com/jpillora/backoff v0.0.0-20170222002228-06c7a16c845d h1:ETeT81zgLgSNc4BWdDO2Fg9ekVItYErbNtE8mKD2pJA=
github.com/jpillora/backoff v0.0.0-20170222002228-06c7a16c845d/go.mod h1:2iMrUgbbvHEiQClaW2NsSzMyGHqN+rDFqY705q49KG0= github.com/jpillora/backoff v0.0.0-20170222002228-06c7a16c845d/go.mod h1:2iMrUgbbvHEiQClaW2NsSzMyGHqN+rDFqY705q49KG0=
github.com/jtolds/gls v4.2.1+incompatible h1:fSuqC+Gmlu6l/ZYAoZzx2pyucC8Xza35fpRVWLVmUEE=
github.com/jtolds/gls v4.2.1+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU= github.com/jtolds/gls v4.2.1+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU=
github.com/kardianos/osext v0.0.0-20170207191655-9b883c5eb462 h1:oSOOTPHkCzMeu1vJ0nHxg5+XZBdMMjNa+6NPnm8arok= github.com/kardianos/osext v0.0.0-20170207191655-9b883c5eb462 h1:oSOOTPHkCzMeu1vJ0nHxg5+XZBdMMjNa+6NPnm8arok=
github.com/kardianos/osext v0.0.0-20170207191655-9b883c5eb462/go.mod h1:1NbS8ALrpOvjt0rHPNLyCIeMtbizbir8U//inJ+zuB8= github.com/kardianos/osext v0.0.0-20170207191655-9b883c5eb462/go.mod h1:1NbS8ALrpOvjt0rHPNLyCIeMtbizbir8U//inJ+zuB8=
github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/labstack/echo v0.0.0-20180219162101-7eec915044a1 h1:cOIt0LZKdfeirAfTP4VtIJuWbjVTGtd1suuPXp/J+dE= github.com/labstack/echo v0.0.0-20180219162101-7eec915044a1 h1:cOIt0LZKdfeirAfTP4VtIJuWbjVTGtd1suuPXp/J+dE=
github.com/labstack/echo v0.0.0-20180219162101-7eec915044a1/go.mod h1:0INS7j/VjnFxD4E2wkz67b8cVwCLbBmJyDaka6Cmk1s= github.com/labstack/echo v0.0.0-20180219162101-7eec915044a1/go.mod h1:0INS7j/VjnFxD4E2wkz67b8cVwCLbBmJyDaka6Cmk1s=
@ -53,14 +47,10 @@ github.com/labstack/gommon v0.2.1 h1:C+I4NYknueQncqKYZQ34kHsLZJVeB5KwPUhnO0nmbpU
github.com/labstack/gommon v0.2.1/go.mod h1:/tj9csK2iPSBvn+3NLM9e52usepMtrd5ilFYA+wQNJ4= github.com/labstack/gommon v0.2.1/go.mod h1:/tj9csK2iPSBvn+3NLM9e52usepMtrd5ilFYA+wQNJ4=
github.com/lrstanley/girc v0.0.0-20180913221000-0fb5b684054e h1:RpktB2igr6nS1EN7bCvjldAEfngrM5GyAbmOa4/cafU= github.com/lrstanley/girc v0.0.0-20180913221000-0fb5b684054e h1:RpktB2igr6nS1EN7bCvjldAEfngrM5GyAbmOa4/cafU=
github.com/lrstanley/girc v0.0.0-20180913221000-0fb5b684054e/go.mod h1:7cRs1SIBfKQ7e3Tam6GKTILSNHzR862JD0JpINaZoJk= github.com/lrstanley/girc v0.0.0-20180913221000-0fb5b684054e/go.mod h1:7cRs1SIBfKQ7e3Tam6GKTILSNHzR862JD0JpINaZoJk=
github.com/lusis/go-slackbot v0.0.0-20180109053408-401027ccfef5 h1:AsEBgzv3DhuYHI/GiQh2HxvTP71HCCE9E/tzGUzGdtU=
github.com/lusis/go-slackbot v0.0.0-20180109053408-401027ccfef5/go.mod h1:c2mYKRyMb1BPkO5St0c/ps62L4S0W2NAkaTXj9qEI+0= github.com/lusis/go-slackbot v0.0.0-20180109053408-401027ccfef5/go.mod h1:c2mYKRyMb1BPkO5St0c/ps62L4S0W2NAkaTXj9qEI+0=
github.com/lusis/slack-test v0.0.0-20180109053238-3c758769bfa6 h1:iOAVXzZyXtW408TMYejlUPo6BIn92HmOacWtIfNyYns=
github.com/lusis/slack-test v0.0.0-20180109053238-3c758769bfa6/go.mod h1:sFlOUpQL1YcjhFVXhg1CG8ZASEs/Mf1oVb6H75JL/zg= github.com/lusis/slack-test v0.0.0-20180109053238-3c758769bfa6/go.mod h1:sFlOUpQL1YcjhFVXhg1CG8ZASEs/Mf1oVb6H75JL/zg=
github.com/magiconair/properties v0.0.0-20180217134545-2c9e95027885 h1:HWxJJvF+QceKcql4r9PC93NtMEgEBfBxlQrZPvbcQvs= github.com/magiconair/properties v0.0.0-20180217134545-2c9e95027885 h1:HWxJJvF+QceKcql4r9PC93NtMEgEBfBxlQrZPvbcQvs=
github.com/magiconair/properties v0.0.0-20180217134545-2c9e95027885/go.mod h1:PppfXfuXeibc/6YijjN8zIbojt8czPbwD3XqdrwzmxQ= github.com/magiconair/properties v0.0.0-20180217134545-2c9e95027885/go.mod h1:PppfXfuXeibc/6YijjN8zIbojt8czPbwD3XqdrwzmxQ=
github.com/matterbridge/discordgo v0.0.0-20180806170629-ef40ff5ba64f h1:9IIOO9Aznn8zJx3nokZ4U6nfuzWw5xAlygPvuRZMisQ=
github.com/matterbridge/discordgo v0.0.0-20180806170629-ef40ff5ba64f/go.mod h1:5QtN542bJn9FunZqYlIbleNtToxfLCVV9pW7m7Q42Fc=
github.com/matterbridge/go-xmpp v0.0.0-20180529212104-cd19799fba91 h1:KzDEcy8eDbTx881giW8a6llsAck3e2bJvMyKvh1IK+k= github.com/matterbridge/go-xmpp v0.0.0-20180529212104-cd19799fba91 h1:KzDEcy8eDbTx881giW8a6llsAck3e2bJvMyKvh1IK+k=
github.com/matterbridge/go-xmpp v0.0.0-20180529212104-cd19799fba91/go.mod h1:ECDRehsR9TYTKCAsRS8/wLeOk6UUqDydw47ln7wG41Q= github.com/matterbridge/go-xmpp v0.0.0-20180529212104-cd19799fba91/go.mod h1:ECDRehsR9TYTKCAsRS8/wLeOk6UUqDydw47ln7wG41Q=
github.com/matterbridge/gomatrix v0.0.0-20171224233421-78ac6a1a0f5f h1:2eKh6Qi/sJ8bXvYMoyVfQxHgR8UcCDWjOmhV1oCstMU= github.com/matterbridge/gomatrix v0.0.0-20171224233421-78ac6a1a0f5f h1:2eKh6Qi/sJ8bXvYMoyVfQxHgR8UcCDWjOmhV1oCstMU=
@ -87,9 +77,7 @@ github.com/nicksnyder/go-i18n v1.4.0 h1:AgLl+Yq7kg5OYlzCgu9cKTZOyI4tD/NgukKqLqC8
github.com/nicksnyder/go-i18n v1.4.0/go.mod h1:HrK7VCrbOvQoUAQ7Vpy7i87N7JZZZ7R2xBGjv0j365Q= github.com/nicksnyder/go-i18n v1.4.0/go.mod h1:HrK7VCrbOvQoUAQ7Vpy7i87N7JZZZ7R2xBGjv0j365Q=
github.com/nlopes/slack v0.4.0 h1:OVnHm7lv5gGT5gkcHsZAyw++oHVFihbjWbL3UceUpiA= github.com/nlopes/slack v0.4.0 h1:OVnHm7lv5gGT5gkcHsZAyw++oHVFihbjWbL3UceUpiA=
github.com/nlopes/slack v0.4.0/go.mod h1:jVI4BBK3lSktibKahxBF74txcK2vyvkza1z/+rRnVAM= github.com/nlopes/slack v0.4.0/go.mod h1:jVI4BBK3lSktibKahxBF74txcK2vyvkza1z/+rRnVAM=
github.com/onsi/ginkgo v1.6.0 h1:Ix8l273rp3QzYgXSR+c8d1fTG7UPgYkOSELPhiY/YGw=
github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE=
github.com/onsi/gomega v1.4.1 h1:PZSj/UFNaVp3KxrzHOcS7oyuWA7LoOY/77yCTEFu21U=
github.com/onsi/gomega v1.4.1/go.mod h1:C1qb7wdrVGGVU+Z6iS04AVkA3Q65CEZX59MT0QO5uiA= github.com/onsi/gomega v1.4.1/go.mod h1:C1qb7wdrVGGVU+Z6iS04AVkA3Q65CEZX59MT0QO5uiA=
github.com/paulrosania/go-charset v0.0.0-20151028000031-621bb39fcc83 h1:XQonH5Iv5rbyIkMJOQ4xKmKHQTh8viXtRSmep5Ca5I4= github.com/paulrosania/go-charset v0.0.0-20151028000031-621bb39fcc83 h1:XQonH5Iv5rbyIkMJOQ4xKmKHQTh8viXtRSmep5Ca5I4=
github.com/paulrosania/go-charset v0.0.0-20151028000031-621bb39fcc83/go.mod h1:YnNlZP7l4MhyGQ4CBRwv6ohZTPrUJJZtEv4ZgADkbs4= github.com/paulrosania/go-charset v0.0.0-20151028000031-621bb39fcc83/go.mod h1:YnNlZP7l4MhyGQ4CBRwv6ohZTPrUJJZtEv4ZgADkbs4=
@ -117,9 +105,7 @@ github.com/shurcooL/sanitized_anchor_name v0.0.0-20170918181015-86672fcb3f95 h1:
github.com/shurcooL/sanitized_anchor_name v0.0.0-20170918181015-86672fcb3f95/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc= github.com/shurcooL/sanitized_anchor_name v0.0.0-20170918181015-86672fcb3f95/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc=
github.com/sirupsen/logrus v0.0.0-20180213143110-8c0189d9f6bb h1:eKjx20EiekBRT2tjZ0XEdKpftfPJQwiavtFshwTyqf0= github.com/sirupsen/logrus v0.0.0-20180213143110-8c0189d9f6bb h1:eKjx20EiekBRT2tjZ0XEdKpftfPJQwiavtFshwTyqf0=
github.com/sirupsen/logrus v0.0.0-20180213143110-8c0189d9f6bb/go.mod h1:pMByvHTf9Beacp5x1UXfOR9xyW/9antXMhjMPG0dEzc= github.com/sirupsen/logrus v0.0.0-20180213143110-8c0189d9f6bb/go.mod h1:pMByvHTf9Beacp5x1UXfOR9xyW/9antXMhjMPG0dEzc=
github.com/smartystreets/assertions v0.0.0-20180803164922-886ec427f6b9 h1:lXQ+j+KwZcbwrbgU0Rp4Eglg3EJLHbuZU3BbOqAGBmg=
github.com/smartystreets/assertions v0.0.0-20180803164922-886ec427f6b9/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc= github.com/smartystreets/assertions v0.0.0-20180803164922-886ec427f6b9/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc=
github.com/smartystreets/goconvey v0.0.0-20180222194500-ef6db91d284a h1:JSvGDIbmil4Ui/dDdFBExb7/cmkNjyX5F97oglmvCDo=
github.com/smartystreets/goconvey v0.0.0-20180222194500-ef6db91d284a/go.mod h1:XDJAKZRPZ1CvBcN2aX5YOUTYGHki24fSF0Iv48Ibg0s= github.com/smartystreets/goconvey v0.0.0-20180222194500-ef6db91d284a/go.mod h1:XDJAKZRPZ1CvBcN2aX5YOUTYGHki24fSF0Iv48Ibg0s=
github.com/spf13/afero v0.0.0-20180211162714-bbf41cb36dff h1:HLvGWId7M56TfuxTeZ6aoiTAcrWO5Mnq/ArwVRgV62I= github.com/spf13/afero v0.0.0-20180211162714-bbf41cb36dff h1:HLvGWId7M56TfuxTeZ6aoiTAcrWO5Mnq/ArwVRgV62I=
github.com/spf13/afero v0.0.0-20180211162714-bbf41cb36dff/go.mod h1:j4pytiNVoe2o6bmDsKpLACNPDBIoEAkihy7loJ1B0CQ= github.com/spf13/afero v0.0.0-20180211162714-bbf41cb36dff/go.mod h1:j4pytiNVoe2o6bmDsKpLACNPDBIoEAkihy7loJ1B0CQ=
@ -139,29 +125,21 @@ github.com/valyala/bytebufferpool v0.0.0-20160817181652-e746df99fe4a h1:AOcehBWp
github.com/valyala/bytebufferpool v0.0.0-20160817181652-e746df99fe4a/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= github.com/valyala/bytebufferpool v0.0.0-20160817181652-e746df99fe4a/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
github.com/valyala/fasttemplate v0.0.0-20170224212429-dcecefd839c4 h1:gKMu1Bf6QINDnvyZuTaACm9ofY+PRh+5vFz4oxBZeF8= github.com/valyala/fasttemplate v0.0.0-20170224212429-dcecefd839c4 h1:gKMu1Bf6QINDnvyZuTaACm9ofY+PRh+5vFz4oxBZeF8=
github.com/valyala/fasttemplate v0.0.0-20170224212429-dcecefd839c4/go.mod h1:50wTf68f99/Zt14pr046Tgt3Lp2vLyFZKzbFXTOabXw= github.com/valyala/fasttemplate v0.0.0-20170224212429-dcecefd839c4/go.mod h1:50wTf68f99/Zt14pr046Tgt3Lp2vLyFZKzbFXTOabXw=
github.com/x-cray/logrus-prefixed-formatter v0.5.2 h1:00txxvfBM9muc0jiLIEAkAcIMJzfthRT6usrui8uGmg=
github.com/x-cray/logrus-prefixed-formatter v0.5.2/go.mod h1:2duySbKsL6M18s5GU7VPsoEPHyzalCE06qoARUCeBBE= github.com/x-cray/logrus-prefixed-formatter v0.5.2/go.mod h1:2duySbKsL6M18s5GU7VPsoEPHyzalCE06qoARUCeBBE=
github.com/zfjagann/golang-ring v0.0.0-20141111230621-17637388c9f6 h1:/WULP+6asFz569UbOwg87f3iDT7T+GF5/vjLmL51Pdk= github.com/zfjagann/golang-ring v0.0.0-20141111230621-17637388c9f6 h1:/WULP+6asFz569UbOwg87f3iDT7T+GF5/vjLmL51Pdk=
github.com/zfjagann/golang-ring v0.0.0-20141111230621-17637388c9f6/go.mod h1:0MsIttMJIF/8Y7x0XjonJP7K99t3sR6bjj4m5S4JmqU= github.com/zfjagann/golang-ring v0.0.0-20141111230621-17637388c9f6/go.mod h1:0MsIttMJIF/8Y7x0XjonJP7K99t3sR6bjj4m5S4JmqU=
golang.org/x/crypto v0.0.0-20180228161326-91a49db82a88 h1:jLkAo/qlT9whgCLYC5GAJ9kcKrv3Wj8VCc4N+KJ4wpw= golang.org/x/crypto v0.0.0-20181030102418-4d3f4d9ffa16 h1:y6ce7gCWtnH+m3dCjzQ1PCuwl28DDIc3VNnvY29DlIA=
golang.org/x/crypto v0.0.0-20180228161326-91a49db82a88/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20181030102418-4d3f4d9ffa16/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
golang.org/x/net v0.0.0-20180108090419-434ec0c7fe37 h1:BkNcmLtAVeWe9h5k0jt24CQgaG5vb4x/doFbAiEC/Ho=
golang.org/x/net v0.0.0-20180108090419-434ec0c7fe37/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180108090419-434ec0c7fe37/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f h1:wMNYb4v58l5UBM7MYRLPG6ZhfOqbKu7X5eyFl8ZhKvA=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20171130163741-8b4580aae2a0 h1:x4M4WCms+ErQg/4VyECbP2kSNcDJ6nLwqEGov1QPtqk= golang.org/x/sys v0.0.0-20171130163741-8b4580aae2a0 h1:x4M4WCms+ErQg/4VyECbP2kSNcDJ6nLwqEGov1QPtqk=
golang.org/x/sys v0.0.0-20171130163741-8b4580aae2a0/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20171130163741-8b4580aae2a0/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/text v0.0.0-20180511172408-5c1cf69b5978 h1:WNm0tmiuBMW4FJRuXKWOqaQfmKptHs0n8nTCyG0ayjc= golang.org/x/text v0.0.0-20180511172408-5c1cf69b5978 h1:WNm0tmiuBMW4FJRuXKWOqaQfmKptHs0n8nTCyG0ayjc=
golang.org/x/text v0.0.0-20180511172408-5c1cf69b5978/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.0.0-20180511172408-5c1cf69b5978/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
gopkg.in/airbrake/gobrake.v2 v2.0.9 h1:7z2uVWwn7oVeeugY1DtlPAy5H+KYgB1KeKTnqjNatLo=
gopkg.in/airbrake/gobrake.v2 v2.0.9/go.mod h1:/h5ZAUhDkGaJfjzjKLSjv6zCL6O0LLBxU4K+aSYdM/U= gopkg.in/airbrake/gobrake.v2 v2.0.9/go.mod h1:/h5ZAUhDkGaJfjzjKLSjv6zCL6O0LLBxU4K+aSYdM/U=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/fsnotify.v1 v1.4.7 h1:xOHLXZwVvI9hhs+cLKq5+I5onOuwQLhQwiu63xxlHs4=
gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys= gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys=
gopkg.in/gemnasium/logrus-airbrake-hook.v2 v2.1.1 h1:4buh9nXkpqc7+GLzDFHei0jwoU9wCQYfVB5Kfo58Yz0=
gopkg.in/gemnasium/logrus-airbrake-hook.v2 v2.1.1/go.mod h1:Xk6kEKp8OKb+X14hQBKWaSkCsqBpgog8nAV2xsGOxlo= gopkg.in/gemnasium/logrus-airbrake-hook.v2 v2.1.1/go.mod h1:Xk6kEKp8OKb+X14hQBKWaSkCsqBpgog8nAV2xsGOxlo=
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ=
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw=
gopkg.in/yaml.v2 v2.0.0-20160301204022-a83829b6f129 h1:RBgb9aPUbZ9nu66ecQNIBNsA7j3mB5h8PNDIfhPjaJg= gopkg.in/yaml.v2 v2.0.0-20160301204022-a83829b6f129 h1:RBgb9aPUbZ9nu66ecQNIBNsA7j3mB5h8PNDIfhPjaJg=
gopkg.in/yaml.v2 v2.0.0-20160301204022-a83829b6f129/go.mod h1:JAlM8MvJe8wmxCU4Bli9HhUf9+ttbYbLASfIpnQbh74= gopkg.in/yaml.v2 v2.0.0-20160301204022-a83829b6f129/go.mod h1:JAlM8MvJe8wmxCU4Bli9HhUf9+ttbYbLASfIpnQbh74=

View File

@ -1,12 +1,12 @@
language: go language: go
go: go:
- 1.7.x
- 1.8.x
- 1.9.x - 1.9.x
- 1.10.x
- 1.11.x
install: install:
- go get github.com/bwmarrin/discordgo - go get github.com/bwmarrin/discordgo
- go get -v . - go get -v .
- go get -v github.com/golang/lint/golint - go get -v golang.org/x/lint/golint
script: script:
- diff <(gofmt -d .) <(echo -n) - diff <(gofmt -d .) <(echo -n)
- go vet -x ./... - go vet -x ./...

View File

@ -15,11 +15,11 @@ to add the official DiscordGo test bot **dgo** to your server. This provides
indispensable help to this project. indispensable help to this project.
* See [dgVoice](https://github.com/bwmarrin/dgvoice) package for an example of * See [dgVoice](https://github.com/bwmarrin/dgvoice) package for an example of
additional voice helper functions and features for DiscordGo additional voice helper functions and features for DiscordGo.
* See [dca](https://github.com/bwmarrin/dca) for an **experimental** stand alone * See [dca](https://github.com/bwmarrin/dca) for an **experimental** stand alone
tool that wraps `ffmpeg` to create opus encoded audio appropriate for use with tool that wraps `ffmpeg` to create opus encoded audio appropriate for use with
Discord (and DiscordGo) Discord (and DiscordGo).
**For help with this package or general Go discussion, please join the [Discord **For help with this package or general Go discussion, please join the [Discord
Gophers](https://discord.gg/0f1SbxBZjYq9jLBk) chat server.** Gophers](https://discord.gg/0f1SbxBZjYq9jLBk) chat server.**
@ -39,9 +39,9 @@ the breaking changes get documented before pushing to master.
*So, what should you use?* *So, what should you use?*
If you can accept the constant changing nature of *develop* then it is the If you can accept the constant changing nature of *develop*, it is the
recommended branch to use. Otherwise, if you want to tail behind development recommended branch to use. Otherwise, if you want to tail behind development
slightly and have a more stable package with documented releases then use *master* slightly and have a more stable package with documented releases, use *master*.
### Installing ### Installing
@ -96,10 +96,10 @@ that information in a nice format.
## Examples ## Examples
Below is a list of examples and other projects using DiscordGo. Please submit Below is a list of examples and other projects using DiscordGo. Please submit
an issue if you would like your project added or removed from this list an issue if you would like your project added or removed from this list.
- [DiscordGo Examples](https://github.com/bwmarrin/discordgo/tree/master/examples) A collection of example programs written with DiscordGo - [DiscordGo Examples](https://github.com/bwmarrin/discordgo/tree/master/examples) - A collection of example programs written with DiscordGo
- [Awesome DiscordGo](https://github.com/bwmarrin/discordgo/wiki/Awesome-DiscordGo) A curated list of high quality projects using DiscordGo - [Awesome DiscordGo](https://github.com/bwmarrin/discordgo/wiki/Awesome-DiscordGo) - A curated list of high quality projects using DiscordGo
## Troubleshooting ## Troubleshooting
For help with common problems please reference the For help with common problems please reference the
@ -114,7 +114,7 @@ Contributions are very welcomed, however please follow the below guidelines.
discussed. discussed.
- Fork the develop branch and make your changes. - Fork the develop branch and make your changes.
- Try to match current naming conventions as closely as possible. - Try to match current naming conventions as closely as possible.
- This package is intended to be a low level direct mapping of the Discord API - This package is intended to be a low level direct mapping of the Discord API,
so please avoid adding enhancements outside of that scope without first so please avoid adding enhancements outside of that scope without first
discussing it. discussing it.
- Create a Pull Request with your changes against the develop branch. - Create a Pull Request with your changes against the develop branch.
@ -127,4 +127,4 @@ comparison and list of other Discord API libraries.
## Special Thanks ## Special Thanks
[Chris Rhodes](https://github.com/iopred) - For the DiscordGo logo and tons of PRs [Chris Rhodes](https://github.com/iopred) - For the DiscordGo logo and tons of PRs.

View File

@ -6,8 +6,8 @@
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// This file contains high level helper functions and easy entry points for the // This file contains high level helper functions and easy entry points for the
// entire discordgo package. These functions are beling developed and are very // entire discordgo package. These functions are being developed and are very
// experimental at this point. They will most likley change so please use the // experimental at this point. They will most likely change so please use the
// low level functions if that's a problem. // low level functions if that's a problem.
// Package discordgo provides Discord binding for Go // Package discordgo provides Discord binding for Go
@ -21,7 +21,7 @@ import (
) )
// VERSION of DiscordGo, follows Semantic Versioning. (http://semver.org/) // VERSION of DiscordGo, follows Semantic Versioning. (http://semver.org/)
const VERSION = "0.18.0" const VERSION = "0.19.0"
// ErrMFA will be risen by New when the user has 2FA. // ErrMFA will be risen by New when the user has 2FA.
var ErrMFA = errors.New("account has 2FA enabled") var ErrMFA = errors.New("account has 2FA enabled")

View File

@ -11,6 +11,8 @@
package discordgo package discordgo
import "strconv"
// APIVersion is the Discord API version used for the REST and Websocket API. // APIVersion is the Discord API version used for the REST and Websocket API.
var APIVersion = "6" var APIVersion = "6"
@ -61,6 +63,10 @@ var (
EndpointUser = func(uID string) string { return EndpointUsers + uID } EndpointUser = func(uID string) string { return EndpointUsers + uID }
EndpointUserAvatar = func(uID, aID string) string { return EndpointCDNAvatars + uID + "/" + aID + ".png" } EndpointUserAvatar = func(uID, aID string) string { return EndpointCDNAvatars + uID + "/" + aID + ".png" }
EndpointUserAvatarAnimated = func(uID, aID string) string { return EndpointCDNAvatars + uID + "/" + aID + ".gif" } EndpointUserAvatarAnimated = func(uID, aID string) string { return EndpointCDNAvatars + uID + "/" + aID + ".gif" }
EndpointDefaultUserAvatar = func(uDiscriminator string) string {
uDiscriminatorInt, _ := strconv.Atoi(uDiscriminator)
return EndpointCDN + "embed/avatars/" + strconv.Itoa(uDiscriminatorInt%5) + ".png"
}
EndpointUserSettings = func(uID string) string { return EndpointUsers + uID + "/settings" } EndpointUserSettings = func(uID string) string { return EndpointUsers + uID + "/settings" }
EndpointUserGuilds = func(uID string) string { return EndpointUsers + uID + "/guilds" } EndpointUserGuilds = func(uID string) string { return EndpointUsers + uID + "/guilds" }
EndpointUserGuild = func(uID, gID string) string { return EndpointUsers + uID + "/guilds/" + gID } EndpointUserGuild = func(uID, gID string) string { return EndpointUsers + uID + "/guilds/" + gID }
@ -88,6 +94,9 @@ var (
EndpointGuildIcon = func(gID, hash string) string { return EndpointCDNIcons + gID + "/" + hash + ".png" } EndpointGuildIcon = func(gID, hash string) string { return EndpointCDNIcons + gID + "/" + hash + ".png" }
EndpointGuildSplash = func(gID, hash string) string { return EndpointCDNSplashes + gID + "/" + hash + ".png" } EndpointGuildSplash = func(gID, hash string) string { return EndpointCDNSplashes + gID + "/" + hash + ".png" }
EndpointGuildWebhooks = func(gID string) string { return EndpointGuilds + gID + "/webhooks" } EndpointGuildWebhooks = func(gID string) string { return EndpointGuilds + gID + "/webhooks" }
EndpointGuildAuditLogs = func(gID string) string { return EndpointGuilds + gID + "/audit-logs" }
EndpointGuildEmojis = func(gID string) string { return EndpointGuilds + gID + "/emojis" }
EndpointGuildEmoji = func(gID, eID string) string { return EndpointGuilds + gID + "/emojis/" + eID }
EndpointChannel = func(cID string) string { return EndpointChannels + cID } EndpointChannel = func(cID string) string { return EndpointChannels + cID }
EndpointChannelPermissions = func(cID string) string { return EndpointChannels + cID + "/permissions" } EndpointChannelPermissions = func(cID string) string { return EndpointChannels + cID + "/permissions" }
@ -128,6 +137,7 @@ var (
EndpointIntegrationsJoin = func(iID string) string { return EndpointAPI + "integrations/" + iID + "/join" } EndpointIntegrationsJoin = func(iID string) string { return EndpointAPI + "integrations/" + iID + "/join" }
EndpointEmoji = func(eID string) string { return EndpointAPI + "emojis/" + eID + ".png" } EndpointEmoji = func(eID string) string { return EndpointAPI + "emojis/" + eID + ".png" }
EndpointEmojiAnimated = func(eID string) string { return EndpointAPI + "emojis/" + eID + ".gif" }
EndpointOauth2 = EndpointAPI + "oauth2/" EndpointOauth2 = EndpointAPI + "oauth2/"
EndpointApplications = EndpointOauth2 + "applications" EndpointApplications = EndpointOauth2 + "applications"

View File

@ -98,7 +98,9 @@ func (s *Session) addEventHandlerOnce(eventHandler EventHandler) func() {
// AddHandler allows you to add an event handler that will be fired anytime // AddHandler allows you to add an event handler that will be fired anytime
// the Discord WSAPI event that matches the function fires. // the Discord WSAPI event that matches the function fires.
// events.go contains all the Discord WSAPI events that can be fired. // The first parameter is a *Session, and the second parameter is a pointer
// to a struct corresponding to the event for which you want to listen.
//
// eg: // eg:
// Session.AddHandler(func(s *discordgo.Session, m *discordgo.MessageCreate) { // Session.AddHandler(func(s *discordgo.Session, m *discordgo.MessageCreate) {
// }) // })
@ -106,6 +108,13 @@ func (s *Session) addEventHandlerOnce(eventHandler EventHandler) func() {
// or: // or:
// Session.AddHandler(func(s *discordgo.Session, m *discordgo.PresenceUpdate) { // Session.AddHandler(func(s *discordgo.Session, m *discordgo.PresenceUpdate) {
// }) // })
//
// List of events can be found at this page, with corresponding names in the
// library for each event: https://discordapp.com/developers/docs/topics/gateway#event-names
// There are also synthetic events fired by the library internally which are
// available for handling, like Connect, Disconnect, and RateLimit.
// events.go contains all of the Discord WSAPI and synthetic events that can be handled.
//
// The return value of this method is a function, that when called will remove the // The return value of this method is a function, that when called will remove the
// event handler. // event handler.
func (s *Session) AddHandler(handler interface{}) func() { func (s *Session) AddHandler(handler interface{}) func() {

View File

@ -50,6 +50,7 @@ const (
userUpdateEventType = "USER_UPDATE" userUpdateEventType = "USER_UPDATE"
voiceServerUpdateEventType = "VOICE_SERVER_UPDATE" voiceServerUpdateEventType = "VOICE_SERVER_UPDATE"
voiceStateUpdateEventType = "VOICE_STATE_UPDATE" voiceStateUpdateEventType = "VOICE_STATE_UPDATE"
webhooksUpdateEventType = "WEBHOOKS_UPDATE"
) )
// channelCreateEventHandler is an event handler for ChannelCreate events. // channelCreateEventHandler is an event handler for ChannelCreate events.
@ -892,6 +893,26 @@ func (eh voiceStateUpdateEventHandler) Handle(s *Session, i interface{}) {
} }
} }
// webhooksUpdateEventHandler is an event handler for WebhooksUpdate events.
type webhooksUpdateEventHandler func(*Session, *WebhooksUpdate)
// Type returns the event type for WebhooksUpdate events.
func (eh webhooksUpdateEventHandler) Type() string {
return webhooksUpdateEventType
}
// New returns a new instance of WebhooksUpdate.
func (eh webhooksUpdateEventHandler) New() interface{} {
return &WebhooksUpdate{}
}
// Handle is the handler for WebhooksUpdate events.
func (eh webhooksUpdateEventHandler) Handle(s *Session, i interface{}) {
if t, ok := i.(*WebhooksUpdate); ok {
eh(s, t)
}
}
func handlerForInterface(handler interface{}) EventHandler { func handlerForInterface(handler interface{}) EventHandler {
switch v := handler.(type) { switch v := handler.(type) {
case func(*Session, interface{}): case func(*Session, interface{}):
@ -982,6 +1003,8 @@ func handlerForInterface(handler interface{}) EventHandler {
return voiceServerUpdateEventHandler(v) return voiceServerUpdateEventHandler(v)
case func(*Session, *VoiceStateUpdate): case func(*Session, *VoiceStateUpdate):
return voiceStateUpdateEventHandler(v) return voiceStateUpdateEventHandler(v)
case func(*Session, *WebhooksUpdate):
return webhooksUpdateEventHandler(v)
} }
return nil return nil
@ -1027,4 +1050,5 @@ func init() {
registerInterfaceProvider(userUpdateEventHandler(nil)) registerInterfaceProvider(userUpdateEventHandler(nil))
registerInterfaceProvider(voiceServerUpdateEventHandler(nil)) registerInterfaceProvider(voiceServerUpdateEventHandler(nil))
registerInterfaceProvider(voiceStateUpdateEventHandler(nil)) registerInterfaceProvider(voiceStateUpdateEventHandler(nil))
registerInterfaceProvider(webhooksUpdateEventHandler(nil))
} }

View File

@ -70,6 +70,7 @@ type ChannelDelete struct {
type ChannelPinsUpdate struct { type ChannelPinsUpdate struct {
LastPinTimestamp string `json:"last_pin_timestamp"` LastPinTimestamp string `json:"last_pin_timestamp"`
ChannelID string `json:"channel_id"` ChannelID string `json:"channel_id"`
GuildID string `json:"guild_id,omitempty"`
} }
// GuildCreate is the data for a GuildCreate event. // GuildCreate is the data for a GuildCreate event.
@ -212,6 +213,7 @@ type RelationshipRemove struct {
type TypingStart struct { type TypingStart struct {
UserID string `json:"user_id"` UserID string `json:"user_id"`
ChannelID string `json:"channel_id"` ChannelID string `json:"channel_id"`
GuildID string `json:"guild_id,omitempty"`
Timestamp int `json:"timestamp"` Timestamp int `json:"timestamp"`
} }
@ -250,4 +252,11 @@ type VoiceStateUpdate struct {
type MessageDeleteBulk struct { type MessageDeleteBulk struct {
Messages []string `json:"ids"` Messages []string `json:"ids"`
ChannelID string `json:"channel_id"` ChannelID string `json:"channel_id"`
GuildID string `json:"guild_id"`
}
// WebhooksUpdate is the data for a WebhooksUpdate event
type WebhooksUpdate struct {
GuildID string `json:"guild_id"`
ChannelID string `json:"channel_id"`
} }

6
vendor/github.com/bwmarrin/discordgo/go.mod generated vendored Normal file
View File

@ -0,0 +1,6 @@
module github.com/bwmarrin/discordgo
require (
github.com/gorilla/websocket v1.4.0
golang.org/x/crypto v0.0.0-20181030102418-4d3f4d9ffa16
)

4
vendor/github.com/bwmarrin/discordgo/go.sum generated vendored Normal file
View File

@ -0,0 +1,4 @@
github.com/gorilla/websocket v1.4.0 h1:WDFjx/TMzVgy9VdMMQi2K2Emtwi2QcUQsztZ/zLaH/Q=
github.com/gorilla/websocket v1.4.0/go.mod h1:E7qHFY5m1UJ88s3WnNqhKjPHQ0heANvMoAMk2YaljkQ=
golang.org/x/crypto v0.0.0-20181030102418-4d3f4d9ffa16 h1:y6ce7gCWtnH+m3dCjzQ1PCuwl28DDIc3VNnvY29DlIA=
golang.org/x/crypto v0.0.0-20181030102418-4d3f4d9ffa16/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=

View File

@ -32,20 +32,59 @@ const (
// A Message stores all data related to a specific Discord message. // A Message stores all data related to a specific Discord message.
type Message struct { type Message struct {
// The ID of the message.
ID string `json:"id"` ID string `json:"id"`
// The ID of the channel in which the message was sent.
ChannelID string `json:"channel_id"` ChannelID string `json:"channel_id"`
// The ID of the guild in which the message was sent.
GuildID string `json:"guild_id,omitempty"`
// The content of the message.
Content string `json:"content"` Content string `json:"content"`
// The time at which the messsage was sent.
// CAUTION: this field may be removed in a
// future API version; it is safer to calculate
// the creation time via the ID.
Timestamp Timestamp `json:"timestamp"` Timestamp Timestamp `json:"timestamp"`
// The time at which the last edit of the message
// occurred, if it has been edited.
EditedTimestamp Timestamp `json:"edited_timestamp"` EditedTimestamp Timestamp `json:"edited_timestamp"`
// The roles mentioned in the message.
MentionRoles []string `json:"mention_roles"` MentionRoles []string `json:"mention_roles"`
// Whether the message is text-to-speech.
Tts bool `json:"tts"` Tts bool `json:"tts"`
// Whether the message mentions everyone.
MentionEveryone bool `json:"mention_everyone"` MentionEveryone bool `json:"mention_everyone"`
// The author of the message. This is not guaranteed to be a
// valid user (webhook-sent messages do not possess a full author).
Author *User `json:"author"` Author *User `json:"author"`
// A list of attachments present in the message.
Attachments []*MessageAttachment `json:"attachments"` Attachments []*MessageAttachment `json:"attachments"`
// A list of embeds present in the message. Multiple
// embeds can currently only be sent by webhooks.
Embeds []*MessageEmbed `json:"embeds"` Embeds []*MessageEmbed `json:"embeds"`
// A list of users mentioned in the message.
Mentions []*User `json:"mentions"` Mentions []*User `json:"mentions"`
// A list of reactions to the message.
Reactions []*MessageReactions `json:"reactions"` Reactions []*MessageReactions `json:"reactions"`
// The type of the message.
Type MessageType `json:"type"` Type MessageType `json:"type"`
// The webhook ID of the message, if it was generated by a webhook
WebhookID string `json:"webhook_id"`
} }
// File stores info about files you e.g. send in messages. // File stores info about files you e.g. send in messages.

View File

@ -38,6 +38,7 @@ var (
ErrPruneDaysBounds = errors.New("the number of days should be more than or equal to 1") ErrPruneDaysBounds = errors.New("the number of days should be more than or equal to 1")
ErrGuildNoIcon = errors.New("guild does not have an icon set") ErrGuildNoIcon = errors.New("guild does not have an icon set")
ErrGuildNoSplash = errors.New("guild does not have a splash set") ErrGuildNoSplash = errors.New("guild does not have a splash set")
ErrUnauthorized = errors.New("HTTP request was unauthorized. This could be because the provided token was not a bot token. Please add \"Bot \" to the start of your token. https://discordapp.com/developers/docs/reference#authentication-example-bot-token-authorization-header")
) )
// Request is the same as RequestWithBucketID but the bucket id is the same as the urlStr // Request is the same as RequestWithBucketID but the bucket id is the same as the urlStr
@ -89,7 +90,7 @@ func (s *Session) RequestWithLockedBucket(method, urlStr, contentType string, b
req.Header.Set("Content-Type", contentType) req.Header.Set("Content-Type", contentType)
// TODO: Make a configurable static variable. // TODO: Make a configurable static variable.
req.Header.Set("User-Agent", fmt.Sprintf("DiscordBot (https://github.com/bwmarrin/discordgo, v%s)", VERSION)) req.Header.Set("User-Agent", "DiscordBot (https://github.com/bwmarrin/discordgo, v"+VERSION+")")
if s.Debug { if s.Debug {
for k, v := range req.Header { for k, v := range req.Header {
@ -129,13 +130,9 @@ func (s *Session) RequestWithLockedBucket(method, urlStr, contentType string, b
} }
switch resp.StatusCode { switch resp.StatusCode {
case http.StatusOK: case http.StatusOK:
case http.StatusCreated: case http.StatusCreated:
case http.StatusNoContent: case http.StatusNoContent:
// TODO check for 401 response, invalidate token if we get one.
case http.StatusBadGateway: case http.StatusBadGateway:
// Retry sending request if possible // Retry sending request if possible
if sequence < s.MaxRestRetries { if sequence < s.MaxRestRetries {
@ -145,7 +142,6 @@ func (s *Session) RequestWithLockedBucket(method, urlStr, contentType string, b
} else { } else {
err = fmt.Errorf("Exceeded Max retries HTTP %s, %s", resp.Status, response) err = fmt.Errorf("Exceeded Max retries HTTP %s, %s", resp.Status, response)
} }
case 429: // TOO MANY REQUESTS - Rate limiting case 429: // TOO MANY REQUESTS - Rate limiting
rl := TooManyRequests{} rl := TooManyRequests{}
err = json.Unmarshal(response, &rl) err = json.Unmarshal(response, &rl)
@ -161,7 +157,12 @@ func (s *Session) RequestWithLockedBucket(method, urlStr, contentType string, b
// this method can cause longer delays than required // this method can cause longer delays than required
response, err = s.RequestWithLockedBucket(method, urlStr, contentType, b, s.Ratelimiter.LockBucketObject(bucket), sequence) response, err = s.RequestWithLockedBucket(method, urlStr, contentType, b, s.Ratelimiter.LockBucketObject(bucket), sequence)
case http.StatusUnauthorized:
if strings.Index(s.Token, "Bot ") != 0 {
s.log(LogInformational, ErrUnauthorized.Error())
err = ErrUnauthorized
}
fallthrough
default: // Error condition default: // Error condition
err = newRestError(req, resp, response) err = newRestError(req, resp, response)
} }
@ -249,7 +250,7 @@ func (s *Session) Register(username string) (token string, err error) {
// even use. // even use.
func (s *Session) Logout() (err error) { func (s *Session) Logout() (err error) {
// _, err = s.Request("POST", LOGOUT, fmt.Sprintf(`{"token": "%s"}`, s.Token)) // _, err = s.Request("POST", LOGOUT, `{"token": "` + s.Token + `"}`)
if s.Token == "" { if s.Token == "" {
return return
@ -361,6 +362,21 @@ func (s *Session) UserUpdateStatus(status Status) (st *Settings, err error) {
return return
} }
// UserConnections returns the user's connections
func (s *Session) UserConnections() (conn []*UserConnection, err error) {
response, err := s.RequestWithBucketID("GET", EndpointUserConnections("@me"), nil, EndpointUserConnections("@me"))
if err != nil {
return nil, err
}
err = unmarshal(response, &conn)
if err != nil {
return
}
return
}
// UserChannels returns an array of Channel structures for all private // UserChannels returns an array of Channel structures for all private
// channels. // channels.
func (s *Session) UserChannels() (st []*Channel, err error) { func (s *Session) UserChannels() (st []*Channel, err error) {
@ -412,7 +428,7 @@ func (s *Session) UserGuilds(limit int, beforeID, afterID string) (st []*UserGui
uri := EndpointUserGuilds("@me") uri := EndpointUserGuilds("@me")
if len(v) > 0 { if len(v) > 0 {
uri = fmt.Sprintf("%s?%s", uri, v.Encode()) uri += "?" + v.Encode()
} }
body, err := s.RequestWithBucketID("GET", uri, nil, EndpointUserGuilds("")) body, err := s.RequestWithBucketID("GET", uri, nil, EndpointUserGuilds(""))
@ -565,7 +581,7 @@ func (s *Session) Guild(guildID string) (st *Guild, err error) {
if s.StateEnabled { if s.StateEnabled {
// Attempt to grab the guild from State first. // Attempt to grab the guild from State first.
st, err = s.State.Guild(guildID) st, err = s.State.Guild(guildID)
if err == nil { if err == nil && !st.Unavailable {
return return
} }
} }
@ -735,7 +751,7 @@ func (s *Session) GuildMembers(guildID string, after string, limit int) (st []*M
} }
if len(v) > 0 { if len(v) > 0 {
uri = fmt.Sprintf("%s?%s", uri, v.Encode()) uri += "?" + v.Encode()
} }
body, err := s.RequestWithBucketID("GET", uri, nil, EndpointGuildMembers(guildID)) body, err := s.RequestWithBucketID("GET", uri, nil, EndpointGuildMembers(guildID))
@ -761,6 +777,32 @@ func (s *Session) GuildMember(guildID, userID string) (st *Member, err error) {
return return
} }
// GuildMemberAdd force joins a user to the guild.
// accessToken : Valid access_token for the user.
// guildID : The ID of a Guild.
// userID : The ID of a User.
// nick : Value to set users nickname to
// roles : A list of role ID's to set on the member.
// mute : If the user is muted.
// deaf : If the user is deafened.
func (s *Session) GuildMemberAdd(accessToken, guildID, userID, nick string, roles []string, mute, deaf bool) (err error) {
data := struct {
AccessToken string `json:"access_token"`
Nick string `json:"nick,omitempty"`
Roles []string `json:"roles,omitempty"`
Mute bool `json:"mute,omitempty"`
Deaf bool `json:"deaf,omitempty"`
}{accessToken, nick, roles, mute, deaf}
_, err = s.RequestWithBucketID("PUT", EndpointGuildMember(guildID, userID), data, EndpointGuildMember(guildID, ""))
if err != nil {
return err
}
return err
}
// GuildMemberDelete removes the given user from the given guild. // GuildMemberDelete removes the given user from the given guild.
// guildID : The ID of a Guild. // guildID : The ID of a Guild.
// userID : The ID of a User // userID : The ID of a User
@ -877,17 +919,22 @@ func (s *Session) GuildChannels(guildID string) (st []*Channel, err error) {
return return
} }
// GuildChannelCreate creates a new channel in the given guild // GuildChannelCreateData is provided to GuildChannelCreateComplex
// guildID : The ID of a Guild. type GuildChannelCreateData struct {
// name : Name of the channel (2-100 chars length)
// ctype : Tpye of the channel (voice or text)
func (s *Session) GuildChannelCreate(guildID, name, ctype string) (st *Channel, err error) {
data := struct {
Name string `json:"name"` Name string `json:"name"`
Type string `json:"type"` Type ChannelType `json:"type"`
}{name, ctype} Topic string `json:"topic,omitempty"`
Bitrate int `json:"bitrate,omitempty"`
UserLimit int `json:"user_limit,omitempty"`
PermissionOverwrites []*PermissionOverwrite `json:"permission_overwrites,omitempty"`
ParentID string `json:"parent_id,omitempty"`
NSFW bool `json:"nsfw,omitempty"`
}
// GuildChannelCreateComplex creates a new channel in the given guild
// guildID : The ID of a Guild
// data : A data struct describing the new Channel, Name and Type are mandatory, other fields depending on the type
func (s *Session) GuildChannelCreateComplex(guildID string, data GuildChannelCreateData) (st *Channel, err error) {
body, err := s.RequestWithBucketID("POST", EndpointGuildChannels(guildID), data, EndpointGuildChannels(guildID)) body, err := s.RequestWithBucketID("POST", EndpointGuildChannels(guildID), data, EndpointGuildChannels(guildID))
if err != nil { if err != nil {
return return
@ -897,12 +944,33 @@ func (s *Session) GuildChannelCreate(guildID, name, ctype string) (st *Channel,
return return
} }
// GuildChannelCreate creates a new channel in the given guild
// guildID : The ID of a Guild.
// name : Name of the channel (2-100 chars length)
// ctype : Type of the channel
func (s *Session) GuildChannelCreate(guildID, name string, ctype ChannelType) (st *Channel, err error) {
return s.GuildChannelCreateComplex(guildID, GuildChannelCreateData{
Name: name,
Type: ctype,
})
}
// GuildChannelsReorder updates the order of channels in a guild // GuildChannelsReorder updates the order of channels in a guild
// guildID : The ID of a Guild. // guildID : The ID of a Guild.
// channels : Updated channels. // channels : Updated channels.
func (s *Session) GuildChannelsReorder(guildID string, channels []*Channel) (err error) { func (s *Session) GuildChannelsReorder(guildID string, channels []*Channel) (err error) {
_, err = s.RequestWithBucketID("PATCH", EndpointGuildChannels(guildID), channels, EndpointGuildChannels(guildID)) data := make([]struct {
ID string `json:"id"`
Position int `json:"position"`
}, len(channels))
for i, c := range channels {
data[i].ID = c.ID
data[i].Position = c.Position
}
_, err = s.RequestWithBucketID("PATCH", EndpointGuildChannels(guildID), data, EndpointGuildChannels(guildID))
return return
} }
@ -1021,7 +1089,7 @@ func (s *Session) GuildPruneCount(guildID string, days uint32) (count uint32, er
Pruned uint32 `json:"pruned"` Pruned uint32 `json:"pruned"`
}{} }{}
uri := EndpointGuildPrune(guildID) + fmt.Sprintf("?days=%d", days) uri := EndpointGuildPrune(guildID) + "?days=" + strconv.FormatUint(uint64(days), 10)
body, err := s.RequestWithBucketID("GET", uri, nil, EndpointGuildPrune(guildID)) body, err := s.RequestWithBucketID("GET", uri, nil, EndpointGuildPrune(guildID))
if err != nil { if err != nil {
return return
@ -1075,7 +1143,7 @@ func (s *Session) GuildPrune(guildID string, days uint32) (count uint32, err err
// GuildIntegrations returns an array of Integrations for a guild. // GuildIntegrations returns an array of Integrations for a guild.
// guildID : The ID of a Guild. // guildID : The ID of a Guild.
func (s *Session) GuildIntegrations(guildID string) (st []*GuildIntegration, err error) { func (s *Session) GuildIntegrations(guildID string) (st []*Integration, err error) {
body, err := s.RequestWithBucketID("GET", EndpointGuildIntegrations(guildID), nil, EndpointGuildIntegrations(guildID)) body, err := s.RequestWithBucketID("GET", EndpointGuildIntegrations(guildID), nil, EndpointGuildIntegrations(guildID))
if err != nil { if err != nil {
@ -1206,6 +1274,94 @@ func (s *Session) GuildEmbedEdit(guildID string, enabled bool, channelID string)
return return
} }
// GuildAuditLog returns the audit log for a Guild.
// guildID : The ID of a Guild.
// userID : If provided the log will be filtered for the given ID.
// beforeID : If provided all log entries returned will be before the given ID.
// actionType : If provided the log will be filtered for the given Action Type.
// limit : The number messages that can be returned. (default 50, min 1, max 100)
func (s *Session) GuildAuditLog(guildID, userID, beforeID string, actionType, limit int) (st *GuildAuditLog, err error) {
uri := EndpointGuildAuditLogs(guildID)
v := url.Values{}
if userID != "" {
v.Set("user_id", userID)
}
if beforeID != "" {
v.Set("before", beforeID)
}
if actionType > 0 {
v.Set("action_type", strconv.Itoa(actionType))
}
if limit > 0 {
v.Set("limit", strconv.Itoa(limit))
}
if len(v) > 0 {
uri = fmt.Sprintf("%s?%s", uri, v.Encode())
}
body, err := s.RequestWithBucketID("GET", uri, nil, EndpointGuildAuditLogs(guildID))
if err != nil {
return
}
err = unmarshal(body, &st)
return
}
// GuildEmojiCreate creates a new emoji
// guildID : The ID of a Guild.
// name : The Name of the Emoji.
// image : The base64 encoded emoji image, has to be smaller than 256KB.
// roles : The roles for which this emoji will be whitelisted, can be nil.
func (s *Session) GuildEmojiCreate(guildID, name, image string, roles []string) (emoji *Emoji, err error) {
data := struct {
Name string `json:"name"`
Image string `json:"image"`
Roles []string `json:"roles,omitempty"`
}{name, image, roles}
body, err := s.RequestWithBucketID("POST", EndpointGuildEmojis(guildID), data, EndpointGuildEmojis(guildID))
if err != nil {
return
}
err = unmarshal(body, &emoji)
return
}
// GuildEmojiEdit modifies an emoji
// guildID : The ID of a Guild.
// emojiID : The ID of an Emoji.
// name : The Name of the Emoji.
// roles : The roles for which this emoji will be whitelisted, can be nil.
func (s *Session) GuildEmojiEdit(guildID, emojiID, name string, roles []string) (emoji *Emoji, err error) {
data := struct {
Name string `json:"name"`
Roles []string `json:"roles,omitempty"`
}{name, roles}
body, err := s.RequestWithBucketID("PATCH", EndpointGuildEmoji(guildID, emojiID), data, EndpointGuildEmojis(guildID))
if err != nil {
return
}
err = unmarshal(body, &emoji)
return
}
// GuildEmojiDelete deletes an Emoji.
// guildID : The ID of a Guild.
// emojiID : The ID of an Emoji.
func (s *Session) GuildEmojiDelete(guildID, emojiID string) (err error) {
_, err = s.RequestWithBucketID("DELETE", EndpointGuildEmoji(guildID, emojiID), nil, EndpointGuildEmojis(guildID))
return
}
// ------------------------------------------------------------------------------------------------ // ------------------------------------------------------------------------------------------------
// Functions specific to Discord Channels // Functions specific to Discord Channels
// ------------------------------------------------------------------------------------------------ // ------------------------------------------------------------------------------------------------
@ -1291,7 +1447,7 @@ func (s *Session) ChannelMessages(channelID string, limit int, beforeID, afterID
v.Set("around", aroundID) v.Set("around", aroundID)
} }
if len(v) > 0 { if len(v) > 0 {
uri = fmt.Sprintf("%s?%s", uri, v.Encode()) uri += "?" + v.Encode()
} }
body, err := s.RequestWithBucketID("GET", uri, nil, EndpointChannelMessages(channelID)) body, err := s.RequestWithBucketID("GET", uri, nil, EndpointChannelMessages(channelID))
@ -1586,7 +1742,8 @@ func (s *Session) ChannelInviteCreate(channelID string, i Invite) (st *Invite, e
MaxAge int `json:"max_age"` MaxAge int `json:"max_age"`
MaxUses int `json:"max_uses"` MaxUses int `json:"max_uses"`
Temporary bool `json:"temporary"` Temporary bool `json:"temporary"`
}{i.MaxAge, i.MaxUses, i.Temporary} Unique bool `json:"unique"`
}{i.MaxAge, i.MaxUses, i.Temporary, i.Unique}
body, err := s.RequestWithBucketID("POST", EndpointChannelInvites(channelID), data, EndpointChannelInvites(channelID)) body, err := s.RequestWithBucketID("POST", EndpointChannelInvites(channelID), data, EndpointChannelInvites(channelID))
if err != nil { if err != nil {
@ -1638,6 +1795,19 @@ func (s *Session) Invite(inviteID string) (st *Invite, err error) {
return return
} }
// InviteWithCounts returns an Invite structure of the given invite including approximate member counts
// inviteID : The invite code
func (s *Session) InviteWithCounts(inviteID string) (st *Invite, err error) {
body, err := s.RequestWithBucketID("GET", EndpointInvite(inviteID)+"?with_counts=true", nil, EndpointInvite(""))
if err != nil {
return
}
err = unmarshal(body, &st)
return
}
// InviteDelete deletes an existing invite // InviteDelete deletes an existing invite
// inviteID : the code of an invite // inviteID : the code of an invite
func (s *Session) InviteDelete(inviteID string) (st *Invite, err error) { func (s *Session) InviteDelete(inviteID string) (st *Invite, err error) {
@ -1830,12 +2000,13 @@ func (s *Session) WebhookWithToken(webhookID, token string) (st *Webhook, err er
// webhookID: The ID of a webhook. // webhookID: The ID of a webhook.
// name : The name of the webhook. // name : The name of the webhook.
// avatar : The avatar of the webhook. // avatar : The avatar of the webhook.
func (s *Session) WebhookEdit(webhookID, name, avatar string) (st *Role, err error) { func (s *Session) WebhookEdit(webhookID, name, avatar, channelID string) (st *Role, err error) {
data := struct { data := struct {
Name string `json:"name,omitempty"` Name string `json:"name,omitempty"`
Avatar string `json:"avatar,omitempty"` Avatar string `json:"avatar,omitempty"`
}{name, avatar} ChannelID string `json:"channel_id,omitempty"`
}{name, avatar, channelID}
body, err := s.RequestWithBucketID("PATCH", EndpointWebhook(webhookID), data, EndpointWebhooks) body, err := s.RequestWithBucketID("PATCH", EndpointWebhook(webhookID), data, EndpointWebhooks)
if err != nil { if err != nil {
@ -1956,7 +2127,7 @@ func (s *Session) MessageReactions(channelID, messageID, emojiID string, limit i
} }
if len(v) > 0 { if len(v) > 0 {
uri = fmt.Sprintf("%s?%s", uri, v.Encode()) uri += "?" + v.Encode()
} }
body, err := s.RequestWithBucketID("GET", uri, nil, EndpointMessageReaction(channelID, "", "", "")) body, err := s.RequestWithBucketID("GET", uri, nil, EndpointMessageReaction(channelID, "", "", ""))

View File

@ -32,6 +32,7 @@ type State struct {
sync.RWMutex sync.RWMutex
Ready Ready
// MaxMessageCount represents how many messages per channel the state will store.
MaxMessageCount int MaxMessageCount int
TrackChannels bool TrackChannels bool
TrackEmojis bool TrackEmojis bool
@ -98,6 +99,9 @@ func (s *State) GuildAdd(guild *Guild) error {
if g, ok := s.guildMap[guild.ID]; ok { if g, ok := s.guildMap[guild.ID]; ok {
// We are about to replace `g` in the state with `guild`, but first we need to // We are about to replace `g` in the state with `guild`, but first we need to
// make sure we preserve any fields that the `guild` doesn't contain from `g`. // make sure we preserve any fields that the `guild` doesn't contain from `g`.
if guild.MemberCount == 0 {
guild.MemberCount = g.MemberCount
}
if guild.Roles == nil { if guild.Roles == nil {
guild.Roles = g.Roles guild.Roles = g.Roles
} }
@ -299,7 +303,12 @@ func (s *State) MemberAdd(member *Member) error {
members[member.User.ID] = member members[member.User.ID] = member
guild.Members = append(guild.Members, member) guild.Members = append(guild.Members, member)
} else { } else {
*m = *member // Update the actual data, which will also update the member pointer in the slice // We are about to replace `m` in the state with `member`, but first we need to
// make sure we preserve any fields that the `member` doesn't contain from `m`.
if member.JoinedAt == "" {
member.JoinedAt = m.JoinedAt
}
*m = *member
} }
return nil return nil
@ -607,7 +616,7 @@ func (s *State) EmojisAdd(guildID string, emojis []*Emoji) error {
// MessageAdd adds a message to the current world state, or updates it if it exists. // MessageAdd adds a message to the current world state, or updates it if it exists.
// If the channel cannot be found, the message is discarded. // If the channel cannot be found, the message is discarded.
// Messages are kept in state up to s.MaxMessageCount // Messages are kept in state up to s.MaxMessageCount per channel.
func (s *State) MessageAdd(message *Message) error { func (s *State) MessageAdd(message *Message) error {
if s == nil { if s == nil {
return ErrNilState return ErrNilState
@ -805,6 +814,14 @@ func (s *State) OnInterface(se *Session, i interface{}) (err error) {
case *GuildDelete: case *GuildDelete:
err = s.GuildRemove(t.Guild) err = s.GuildRemove(t.Guild)
case *GuildMemberAdd: case *GuildMemberAdd:
// Updates the MemberCount of the guild.
guild, err := s.Guild(t.Member.GuildID)
if err != nil {
return err
}
guild.MemberCount++
// Caches member if tracking is enabled.
if s.TrackMembers { if s.TrackMembers {
err = s.MemberAdd(t.Member) err = s.MemberAdd(t.Member)
} }
@ -813,6 +830,14 @@ func (s *State) OnInterface(se *Session, i interface{}) (err error) {
err = s.MemberAdd(t.Member) err = s.MemberAdd(t.Member)
} }
case *GuildMemberRemove: case *GuildMemberRemove:
// Updates the MemberCount of the guild.
guild, err := s.Guild(t.Member.GuildID)
if err != nil {
return err
}
guild.MemberCount--
// Removes member from the cache if tracking is enabled.
if s.TrackMembers { if s.TrackMembers {
err = s.MemberRemove(t.Member) err = s.MemberRemove(t.Member)
} }

View File

@ -13,6 +13,7 @@ package discordgo
import ( import (
"encoding/json" "encoding/json"
"fmt"
"net/http" "net/http"
"sync" "sync"
"time" "time"
@ -84,6 +85,9 @@ type Session struct {
// Stores the last HeartbeatAck that was recieved (in UTC) // Stores the last HeartbeatAck that was recieved (in UTC)
LastHeartbeatAck time.Time LastHeartbeatAck time.Time
// Stores the last Heartbeat sent (in UTC)
LastHeartbeatSent time.Time
// used to deal with rate limits // used to deal with rate limits
Ratelimiter *RateLimiter Ratelimiter *RateLimiter
@ -111,6 +115,37 @@ type Session struct {
wsMutex sync.Mutex wsMutex sync.Mutex
} }
// UserConnection is a Connection returned from the UserConnections endpoint
type UserConnection struct {
ID string `json:"id"`
Name string `json:"name"`
Type string `json:"type"`
Revoked bool `json:"revoked"`
Integrations []*Integration `json:"integrations"`
}
// Integration stores integration information
type Integration struct {
ID string `json:"id"`
Name string `json:"name"`
Type string `json:"type"`
Enabled bool `json:"enabled"`
Syncing bool `json:"syncing"`
RoleID string `json:"role_id"`
ExpireBehavior int `json:"expire_behavior"`
ExpireGracePeriod int `json:"expire_grace_period"`
User *User `json:"user"`
Account IntegrationAccount `json:"account"`
SyncedAt Timestamp `json:"synced_at"`
}
// IntegrationAccount is integration account information
// sent by the UserConnections endpoint
type IntegrationAccount struct {
ID string `json:"id"`
Name string `json:"name"`
}
// A VoiceRegion stores data for a specific voice region server. // A VoiceRegion stores data for a specific voice region server.
type VoiceRegion struct { type VoiceRegion struct {
ID string `json:"id"` ID string `json:"id"`
@ -145,6 +180,10 @@ type Invite struct {
Revoked bool `json:"revoked"` Revoked bool `json:"revoked"`
Temporary bool `json:"temporary"` Temporary bool `json:"temporary"`
Unique bool `json:"unique"` Unique bool `json:"unique"`
// will only be filled when using InviteWithCounts
ApproximatePresenceCount int `json:"approximate_presence_count"`
ApproximateMemberCount int `json:"approximate_member_count"`
} }
// ChannelType is the type of a Channel // ChannelType is the type of a Channel
@ -161,22 +200,61 @@ const (
// A Channel holds all data related to an individual Discord channel. // A Channel holds all data related to an individual Discord channel.
type Channel struct { type Channel struct {
// The ID of the channel.
ID string `json:"id"` ID string `json:"id"`
// The ID of the guild to which the channel belongs, if it is in a guild.
// Else, this ID is empty (e.g. DM channels).
GuildID string `json:"guild_id"` GuildID string `json:"guild_id"`
// The name of the channel.
Name string `json:"name"` Name string `json:"name"`
// The topic of the channel.
Topic string `json:"topic"` Topic string `json:"topic"`
// The type of the channel.
Type ChannelType `json:"type"` Type ChannelType `json:"type"`
// The ID of the last message sent in the channel. This is not
// guaranteed to be an ID of a valid message.
LastMessageID string `json:"last_message_id"` LastMessageID string `json:"last_message_id"`
// Whether the channel is marked as NSFW.
NSFW bool `json:"nsfw"` NSFW bool `json:"nsfw"`
// Icon of the group DM channel.
Icon string `json:"icon"`
// The position of the channel, used for sorting in client.
Position int `json:"position"` Position int `json:"position"`
// The bitrate of the channel, if it is a voice channel.
Bitrate int `json:"bitrate"` Bitrate int `json:"bitrate"`
// The recipients of the channel. This is only populated in DM channels.
Recipients []*User `json:"recipients"` Recipients []*User `json:"recipients"`
// The messages in the channel. This is only present in state-cached channels,
// and State.MaxMessageCount must be non-zero.
Messages []*Message `json:"-"` Messages []*Message `json:"-"`
// A list of permission overwrites present for the channel.
PermissionOverwrites []*PermissionOverwrite `json:"permission_overwrites"` PermissionOverwrites []*PermissionOverwrite `json:"permission_overwrites"`
// The user limit of the voice channel.
UserLimit int `json:"user_limit"`
// The ID of the parent channel, if the channel is under a category
ParentID string `json:"parent_id"` ParentID string `json:"parent_id"`
} }
// A ChannelEdit holds Channel Feild data for a channel edit. // Mention returns a string which mentions the channel
func (c *Channel) Mention() string {
return fmt.Sprintf("<#%s>", c.ID)
}
// A ChannelEdit holds Channel Field data for a channel edit.
type ChannelEdit struct { type ChannelEdit struct {
Name string `json:"name,omitempty"` Name string `json:"name,omitempty"`
Topic string `json:"topic,omitempty"` Topic string `json:"topic,omitempty"`
@ -186,6 +264,7 @@ type ChannelEdit struct {
UserLimit int `json:"user_limit,omitempty"` UserLimit int `json:"user_limit,omitempty"`
PermissionOverwrites []*PermissionOverwrite `json:"permission_overwrites,omitempty"` PermissionOverwrites []*PermissionOverwrite `json:"permission_overwrites,omitempty"`
ParentID string `json:"parent_id,omitempty"` ParentID string `json:"parent_id,omitempty"`
RateLimitPerUser int `json:"rate_limit_per_user,omitempty"`
} }
// A PermissionOverwrite holds permission overwrite data for a Channel // A PermissionOverwrite holds permission overwrite data for a Channel
@ -206,6 +285,19 @@ type Emoji struct {
Animated bool `json:"animated"` Animated bool `json:"animated"`
} }
// MessageFormat returns a correctly formatted Emoji for use in Message content and embeds
func (e *Emoji) MessageFormat() string {
if e.ID != "" && e.Name != "" {
if e.Animated {
return "<a:" + e.APIName() + ">"
}
return "<:" + e.APIName() + ">"
}
return e.APIName()
}
// APIName returns an correctly formatted API name for use in the MessageReactions endpoints. // APIName returns an correctly formatted API name for use in the MessageReactions endpoints.
func (e *Emoji) APIName() string { func (e *Emoji) APIName() string {
if e.ID != "" && e.Name != "" { if e.ID != "" && e.Name != "" {
@ -228,31 +320,129 @@ const (
VerificationLevelHigh VerificationLevelHigh
) )
// ExplicitContentFilterLevel type definition
type ExplicitContentFilterLevel int
// Constants for ExplicitContentFilterLevel levels from 0 to 2 inclusive
const (
ExplicitContentFilterDisabled ExplicitContentFilterLevel = iota
ExplicitContentFilterMembersWithoutRoles
ExplicitContentFilterAllMembers
)
// MfaLevel type definition
type MfaLevel int
// Constants for MfaLevel levels from 0 to 1 inclusive
const (
MfaLevelNone MfaLevel = iota
MfaLevelElevated
)
// A Guild holds all data related to a specific Discord Guild. Guilds are also // A Guild holds all data related to a specific Discord Guild. Guilds are also
// sometimes referred to as Servers in the Discord client. // sometimes referred to as Servers in the Discord client.
type Guild struct { type Guild struct {
// The ID of the guild.
ID string `json:"id"` ID string `json:"id"`
// The name of the guild. (2100 characters)
Name string `json:"name"` Name string `json:"name"`
// The hash of the guild's icon. Use Session.GuildIcon
// to retrieve the icon itself.
Icon string `json:"icon"` Icon string `json:"icon"`
// The voice region of the guild.
Region string `json:"region"` Region string `json:"region"`
// The ID of the AFK voice channel.
AfkChannelID string `json:"afk_channel_id"` AfkChannelID string `json:"afk_channel_id"`
// The ID of the embed channel ID, used for embed widgets.
EmbedChannelID string `json:"embed_channel_id"` EmbedChannelID string `json:"embed_channel_id"`
// The user ID of the owner of the guild.
OwnerID string `json:"owner_id"` OwnerID string `json:"owner_id"`
// The time at which the current user joined the guild.
// This field is only present in GUILD_CREATE events and websocket
// update events, and thus is only present in state-cached guilds.
JoinedAt Timestamp `json:"joined_at"` JoinedAt Timestamp `json:"joined_at"`
// The hash of the guild's splash.
Splash string `json:"splash"` Splash string `json:"splash"`
// The timeout, in seconds, before a user is considered AFK in voice.
AfkTimeout int `json:"afk_timeout"` AfkTimeout int `json:"afk_timeout"`
// The number of members in the guild.
// This field is only present in GUILD_CREATE events and websocket
// update events, and thus is only present in state-cached guilds.
MemberCount int `json:"member_count"` MemberCount int `json:"member_count"`
// The verification level required for the guild.
VerificationLevel VerificationLevel `json:"verification_level"` VerificationLevel VerificationLevel `json:"verification_level"`
// Whether the guild has embedding enabled.
EmbedEnabled bool `json:"embed_enabled"` EmbedEnabled bool `json:"embed_enabled"`
Large bool `json:"large"` // ??
// Whether the guild is considered large. This is
// determined by a member threshold in the identify packet,
// and is currently hard-coded at 250 members in the library.
Large bool `json:"large"`
// The default message notification setting for the guild.
// 0 == all messages, 1 == mentions only.
DefaultMessageNotifications int `json:"default_message_notifications"` DefaultMessageNotifications int `json:"default_message_notifications"`
// A list of roles in the guild.
Roles []*Role `json:"roles"` Roles []*Role `json:"roles"`
// A list of the custom emojis present in the guild.
Emojis []*Emoji `json:"emojis"` Emojis []*Emoji `json:"emojis"`
// A list of the members in the guild.
// This field is only present in GUILD_CREATE events and websocket
// update events, and thus is only present in state-cached guilds.
Members []*Member `json:"members"` Members []*Member `json:"members"`
// A list of partial presence objects for members in the guild.
// This field is only present in GUILD_CREATE events and websocket
// update events, and thus is only present in state-cached guilds.
Presences []*Presence `json:"presences"` Presences []*Presence `json:"presences"`
// A list of channels in the guild.
// This field is only present in GUILD_CREATE events and websocket
// update events, and thus is only present in state-cached guilds.
Channels []*Channel `json:"channels"` Channels []*Channel `json:"channels"`
// A list of voice states for the guild.
// This field is only present in GUILD_CREATE events and websocket
// update events, and thus is only present in state-cached guilds.
VoiceStates []*VoiceState `json:"voice_states"` VoiceStates []*VoiceState `json:"voice_states"`
// Whether this guild is currently unavailable (most likely due to outage).
// This field is only present in GUILD_CREATE events and websocket
// update events, and thus is only present in state-cached guilds.
Unavailable bool `json:"unavailable"` Unavailable bool `json:"unavailable"`
// The explicit content filter level
ExplicitContentFilter ExplicitContentFilterLevel `json:"explicit_content_filter"`
// The list of enabled guild features
Features []string `json:"features"`
// Required MFA level for the guild
MfaLevel MfaLevel `json:"mfa_level"`
// Whether or not the Server Widget is enabled
WidgetEnabled bool `json:"widget_enabled"`
// The Channel ID for the Server Widget
WidgetChannelID string `json:"widget_channel_id"`
// The Channel ID to which system messages are sent (eg join and leave messages)
SystemChannelID string `json:"system_channel_id"`
} }
// A UserGuild holds a brief version of a Guild // A UserGuild holds a brief version of a Guild
@ -279,16 +469,39 @@ type GuildParams struct {
// A Role stores information about Discord guild member roles. // A Role stores information about Discord guild member roles.
type Role struct { type Role struct {
// The ID of the role.
ID string `json:"id"` ID string `json:"id"`
// The name of the role.
Name string `json:"name"` Name string `json:"name"`
// Whether this role is managed by an integration, and
// thus cannot be manually added to, or taken from, members.
Managed bool `json:"managed"` Managed bool `json:"managed"`
// Whether this role is mentionable.
Mentionable bool `json:"mentionable"` Mentionable bool `json:"mentionable"`
// Whether this role is hoisted (shows up separately in member list).
Hoist bool `json:"hoist"` Hoist bool `json:"hoist"`
// The hex color of this role.
Color int `json:"color"` Color int `json:"color"`
// The position of this role in the guild's role hierarchy.
Position int `json:"position"` Position int `json:"position"`
// The permissions of the role on the guild (doesn't include channel overrides).
// This is a combination of bit masks; the presence of a certain permission can
// be checked by performing a bitwise AND between this int and the permission.
Permissions int `json:"permissions"` Permissions int `json:"permissions"`
} }
// Mention returns a string which mentions the role
func (r *Role) Mention() string {
return fmt.Sprintf("<@&%s>", r.ID)
}
// Roles are a collection of Role // Roles are a collection of Role
type Roles []*Role type Roles []*Role
@ -334,6 +547,8 @@ type GameType int
const ( const (
GameTypeGame GameType = iota GameTypeGame GameType = iota
GameTypeStreaming GameTypeStreaming
GameTypeListening
GameTypeWatching
) )
// A Game struct holds the name of the "playing .." game for a user // A Game struct holds the name of the "playing .." game for a user
@ -379,17 +594,36 @@ type Assets struct {
SmallText string `json:"small_text,omitempty"` SmallText string `json:"small_text,omitempty"`
} }
// A Member stores user information for Guild members. // A Member stores user information for Guild members. A guild
// member represents a certain user's presence in a guild.
type Member struct { type Member struct {
// The guild ID on which the member exists.
GuildID string `json:"guild_id"` GuildID string `json:"guild_id"`
JoinedAt string `json:"joined_at"`
// The time at which the member joined the guild, in ISO8601.
JoinedAt Timestamp `json:"joined_at"`
// The nickname of the member, if they have one.
Nick string `json:"nick"` Nick string `json:"nick"`
// Whether the member is deafened at a guild level.
Deaf bool `json:"deaf"` Deaf bool `json:"deaf"`
// Whether the member is muted at a guild level.
Mute bool `json:"mute"` Mute bool `json:"mute"`
// The underlying user on which the member is based.
User *User `json:"user"` User *User `json:"user"`
// A list of IDs of the roles which are possessed by the member.
Roles []string `json:"roles"` Roles []string `json:"roles"`
} }
// Mention creates a member mention
func (m *Member) Mention() string {
return "<@!" + m.User.ID + ">"
}
// A Settings stores data for a specific users Discord client settings. // A Settings stores data for a specific users Discord client settings.
type Settings struct { type Settings struct {
RenderEmbeds bool `json:"render_embeds"` RenderEmbeds bool `json:"render_embeds"`
@ -467,33 +701,88 @@ type GuildBan struct {
User *User `json:"user"` User *User `json:"user"`
} }
// A GuildIntegration stores data for a guild integration.
type GuildIntegration struct {
ID string `json:"id"`
Name string `json:"name"`
Type string `json:"type"`
Enabled bool `json:"enabled"`
Syncing bool `json:"syncing"`
RoleID string `json:"role_id"`
ExpireBehavior int `json:"expire_behavior"`
ExpireGracePeriod int `json:"expire_grace_period"`
User *User `json:"user"`
Account *GuildIntegrationAccount `json:"account"`
SyncedAt int `json:"synced_at"`
}
// A GuildIntegrationAccount stores data for a guild integration account.
type GuildIntegrationAccount struct {
ID string `json:"id"`
Name string `json:"name"`
}
// A GuildEmbed stores data for a guild embed. // A GuildEmbed stores data for a guild embed.
type GuildEmbed struct { type GuildEmbed struct {
Enabled bool `json:"enabled"` Enabled bool `json:"enabled"`
ChannelID string `json:"channel_id"` ChannelID string `json:"channel_id"`
} }
// A GuildAuditLog stores data for a guild audit log.
type GuildAuditLog struct {
Webhooks []struct {
ChannelID string `json:"channel_id"`
GuildID string `json:"guild_id"`
ID string `json:"id"`
Avatar string `json:"avatar"`
Name string `json:"name"`
} `json:"webhooks,omitempty"`
Users []struct {
Username string `json:"username"`
Discriminator string `json:"discriminator"`
Bot bool `json:"bot"`
ID string `json:"id"`
Avatar string `json:"avatar"`
} `json:"users,omitempty"`
AuditLogEntries []struct {
TargetID string `json:"target_id"`
Changes []struct {
NewValue interface{} `json:"new_value"`
OldValue interface{} `json:"old_value"`
Key string `json:"key"`
} `json:"changes,omitempty"`
UserID string `json:"user_id"`
ID string `json:"id"`
ActionType int `json:"action_type"`
Options struct {
DeleteMembersDay string `json:"delete_member_days"`
MembersRemoved string `json:"members_removed"`
ChannelID string `json:"channel_id"`
Count string `json:"count"`
ID string `json:"id"`
Type string `json:"type"`
RoleName string `json:"role_name"`
} `json:"options,omitempty"`
Reason string `json:"reason"`
} `json:"audit_log_entries"`
}
// Block contains Discord Audit Log Action Types
const (
AuditLogActionGuildUpdate = 1
AuditLogActionChannelCreate = 10
AuditLogActionChannelUpdate = 11
AuditLogActionChannelDelete = 12
AuditLogActionChannelOverwriteCreate = 13
AuditLogActionChannelOverwriteUpdate = 14
AuditLogActionChannelOverwriteDelete = 15
AuditLogActionMemberKick = 20
AuditLogActionMemberPrune = 21
AuditLogActionMemberBanAdd = 22
AuditLogActionMemberBanRemove = 23
AuditLogActionMemberUpdate = 24
AuditLogActionMemberRoleUpdate = 25
AuditLogActionRoleCreate = 30
AuditLogActionRoleUpdate = 31
AuditLogActionRoleDelete = 32
AuditLogActionInviteCreate = 40
AuditLogActionInviteUpdate = 41
AuditLogActionInviteDelete = 42
AuditLogActionWebhookCreate = 50
AuditLogActionWebhookUpdate = 51
AuditLogActionWebhookDelete = 52
AuditLogActionEmojiCreate = 60
AuditLogActionEmojiUpdate = 61
AuditLogActionEmojiDelete = 62
AuditLogActionMessageDelete = 72
)
// A UserGuildSettingsChannelOverride stores data for a channel override for a users guild settings. // A UserGuildSettingsChannelOverride stores data for a channel override for a users guild settings.
type UserGuildSettingsChannelOverride struct { type UserGuildSettingsChannelOverride struct {
Muted bool `json:"muted"` Muted bool `json:"muted"`
@ -553,6 +842,7 @@ type MessageReaction struct {
MessageID string `json:"message_id"` MessageID string `json:"message_id"`
Emoji Emoji `json:"emoji"` Emoji Emoji `json:"emoji"`
ChannelID string `json:"channel_id"` ChannelID string `json:"channel_id"`
GuildID string `json:"guild_id,omitempty"`
} }
// GatewayBotResponse stores the data for the gateway/bot response // GatewayBotResponse stores the data for the gateway/bot response
@ -629,7 +919,9 @@ const (
PermissionKickMembers | PermissionKickMembers |
PermissionBanMembers | PermissionBanMembers |
PermissionManageServer | PermissionManageServer |
PermissionAdministrator PermissionAdministrator |
PermissionManageWebhooks |
PermissionManageEmojis
) )
// Block contains Discord JSON Error Response codes // Block contains Discord JSON Error Response codes
@ -648,6 +940,7 @@ const (
ErrCodeUnknownToken = 10012 ErrCodeUnknownToken = 10012
ErrCodeUnknownUser = 10013 ErrCodeUnknownUser = 10013
ErrCodeUnknownEmoji = 10014 ErrCodeUnknownEmoji = 10014
ErrCodeUnknownWebhook = 10015
ErrCodeBotsCannotUseEndpoint = 20001 ErrCodeBotsCannotUseEndpoint = 20001
ErrCodeOnlyBotsCanUseEndpoint = 20002 ErrCodeOnlyBotsCanUseEndpoint = 20002

View File

@ -11,7 +11,6 @@ package discordgo
import ( import (
"encoding/json" "encoding/json"
"fmt"
"net/http" "net/http"
"time" "time"
) )
@ -54,5 +53,5 @@ func newRestError(req *http.Request, resp *http.Response, body []byte) *RESTErro
} }
func (r RESTError) Error() string { func (r RESTError) Error() string {
return fmt.Sprintf("HTTP %s, %s", r.Response.Status, r.ResponseBody) return "HTTP " + r.Response.Status + ", " + string(r.ResponseBody)
} }

69
vendor/github.com/bwmarrin/discordgo/user.go generated vendored Normal file
View File

@ -0,0 +1,69 @@
package discordgo
import "strings"
// A User stores all data for an individual Discord user.
type User struct {
// The ID of the user.
ID string `json:"id"`
// The email of the user. This is only present when
// the application possesses the email scope for the user.
Email string `json:"email"`
// The user's username.
Username string `json:"username"`
// The hash of the user's avatar. Use Session.UserAvatar
// to retrieve the avatar itself.
Avatar string `json:"avatar"`
// The user's chosen language option.
Locale string `json:"locale"`
// The discriminator of the user (4 numbers after name).
Discriminator string `json:"discriminator"`
// The token of the user. This is only present for
// the user represented by the current session.
Token string `json:"token"`
// Whether the user's email is verified.
Verified bool `json:"verified"`
// Whether the user has multi-factor authentication enabled.
MFAEnabled bool `json:"mfa_enabled"`
// Whether the user is a bot.
Bot bool `json:"bot"`
}
// String returns a unique identifier of the form username#discriminator
func (u *User) String() string {
return u.Username + "#" + u.Discriminator
}
// Mention return a string which mentions the user
func (u *User) Mention() string {
return "<@" + u.ID + ">"
}
// AvatarURL returns a URL to the user's avatar.
// size: The size of the user's avatar as a power of two
// if size is an empty string, no size parameter will
// be added to the URL.
func (u *User) AvatarURL(size string) string {
var URL string
if u.Avatar == "" {
URL = EndpointDefaultUserAvatar(u.Discriminator)
} else if strings.HasPrefix(u.Avatar, "a_") {
URL = EndpointUserAvatarAnimated(u.ID, u.Avatar)
} else {
URL = EndpointUserAvatar(u.ID, u.Avatar)
}
if size != "" {
return URL + "?size=" + size
}
return URL
}

View File

@ -14,6 +14,7 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"net" "net"
"strconv"
"strings" "strings"
"sync" "sync"
"time" "time"
@ -103,7 +104,7 @@ func (v *VoiceConnection) Speaking(b bool) (err error) {
defer v.Unlock() defer v.Unlock()
if err != nil { if err != nil {
v.speaking = false v.speaking = false
v.log(LogError, "Speaking() write json error:", err) v.log(LogError, "Speaking() write json error, %s", err)
return return
} }
@ -135,7 +136,6 @@ func (v *VoiceConnection) ChangeChannel(channelID string, mute, deaf bool) (err
// Disconnect disconnects from this voice channel and closes the websocket // Disconnect disconnects from this voice channel and closes the websocket
// and udp connections to Discord. // and udp connections to Discord.
// !!! NOTE !!! this function may be removed in favour of ChannelVoiceLeave
func (v *VoiceConnection) Disconnect() (err error) { func (v *VoiceConnection) Disconnect() (err error) {
// Send a OP4 with a nil channel to disconnect // Send a OP4 with a nil channel to disconnect
@ -180,7 +180,7 @@ func (v *VoiceConnection) Close() {
v.log(LogInformational, "closing udp") v.log(LogInformational, "closing udp")
err := v.udpConn.Close() err := v.udpConn.Close()
if err != nil { if err != nil {
v.log(LogError, "error closing udp connection: ", err) v.log(LogError, "error closing udp connection, %s", err)
} }
v.udpConn = nil v.udpConn = nil
} }
@ -299,7 +299,7 @@ func (v *VoiceConnection) open() (err error) {
} }
// Connect to VoiceConnection Websocket // Connect to VoiceConnection Websocket
vg := fmt.Sprintf("wss://%s", strings.TrimSuffix(v.endpoint, ":80")) vg := "wss://" + strings.TrimSuffix(v.endpoint, ":80")
v.log(LogInformational, "connecting to voice endpoint %s", vg) v.log(LogInformational, "connecting to voice endpoint %s", vg)
v.wsConn, _, err = websocket.DefaultDialer.Dial(vg, nil) v.wsConn, _, err = websocket.DefaultDialer.Dial(vg, nil)
if err != nil { if err != nil {
@ -542,7 +542,7 @@ func (v *VoiceConnection) udpOpen() (err error) {
return fmt.Errorf("empty endpoint") return fmt.Errorf("empty endpoint")
} }
host := fmt.Sprintf("%s:%d", strings.TrimSuffix(v.endpoint, ":80"), v.op2.Port) host := strings.TrimSuffix(v.endpoint, ":80") + ":" + strconv.Itoa(v.op2.Port)
addr, err := net.ResolveUDPAddr("udp", host) addr, err := net.ResolveUDPAddr("udp", host)
if err != nil { if err != nil {
v.log(LogWarning, "error resolving udp host %s, %s", host, err) v.log(LogWarning, "error resolving udp host %s, %s", host, err)

View File

@ -86,6 +86,10 @@ func (s *Session) Open() error {
return err return err
} }
s.wsConn.SetCloseHandler(func(code int, text string) error {
return nil
})
defer func() { defer func() {
// because of this, all code below must set err to the error // because of this, all code below must set err to the error
// when exiting with an error :) Maybe someone has a better // when exiting with an error :) Maybe someone has a better
@ -263,6 +267,13 @@ type helloOp struct {
// FailedHeartbeatAcks is the Number of heartbeat intervals to wait until forcing a connection restart. // FailedHeartbeatAcks is the Number of heartbeat intervals to wait until forcing a connection restart.
const FailedHeartbeatAcks time.Duration = 5 * time.Millisecond const FailedHeartbeatAcks time.Duration = 5 * time.Millisecond
// HeartbeatLatency returns the latency between heartbeat acknowledgement and heartbeat send.
func (s *Session) HeartbeatLatency() time.Duration {
return s.LastHeartbeatAck.Sub(s.LastHeartbeatSent)
}
// heartbeat sends regular heartbeats to Discord so it knows the client // heartbeat sends regular heartbeats to Discord so it knows the client
// is still connected. If you do not send these heartbeats Discord will // is still connected. If you do not send these heartbeats Discord will
// disconnect the websocket connection after a few seconds. // disconnect the websocket connection after a few seconds.
@ -283,8 +294,9 @@ func (s *Session) heartbeat(wsConn *websocket.Conn, listening <-chan interface{}
last := s.LastHeartbeatAck last := s.LastHeartbeatAck
s.RUnlock() s.RUnlock()
sequence := atomic.LoadInt64(s.sequence) sequence := atomic.LoadInt64(s.sequence)
s.log(LogInformational, "sending gateway websocket heartbeat seq %d", sequence) s.log(LogDebug, "sending gateway websocket heartbeat seq %d", sequence)
s.wsMutex.Lock() s.wsMutex.Lock()
s.LastHeartbeatSent = time.Now().UTC()
err = wsConn.WriteJSON(heartbeatOp{1, sequence}) err = wsConn.WriteJSON(heartbeatOp{1, sequence})
s.wsMutex.Unlock() s.wsMutex.Unlock()
if err != nil || time.Now().UTC().Sub(last) > (heartbeatIntervalMsec*FailedHeartbeatAcks) { if err != nil || time.Now().UTC().Sub(last) > (heartbeatIntervalMsec*FailedHeartbeatAcks) {
@ -323,16 +335,8 @@ type updateStatusOp struct {
Data UpdateStatusData `json:"d"` Data UpdateStatusData `json:"d"`
} }
// UpdateStreamingStatus is used to update the user's streaming status. func newUpdateStatusData(idle int, gameType GameType, game, url string) *UpdateStatusData {
// If idle>0 then set status to idle. usd := &UpdateStatusData{
// If game!="" then set game.
// If game!="" and url!="" then set the status type to streaming with the URL set.
// if otherwise, set status to active, and no game.
func (s *Session) UpdateStreamingStatus(idle int, game string, url string) (err error) {
s.log(LogInformational, "called")
usd := UpdateStatusData{
Status: "online", Status: "online",
} }
@ -341,10 +345,6 @@ func (s *Session) UpdateStreamingStatus(idle int, game string, url string) (err
} }
if game != "" { if game != "" {
gameType := GameTypeGame
if url != "" {
gameType = GameTypeStreaming
}
usd.Game = &Game{ usd.Game = &Game{
Name: game, Name: game,
Type: gameType, Type: gameType,
@ -352,7 +352,35 @@ func (s *Session) UpdateStreamingStatus(idle int, game string, url string) (err
} }
} }
return s.UpdateStatusComplex(usd) return usd
}
// UpdateStatus is used to update the user's status.
// If idle>0 then set status to idle.
// If game!="" then set game.
// if otherwise, set status to active, and no game.
func (s *Session) UpdateStatus(idle int, game string) (err error) {
return s.UpdateStatusComplex(*newUpdateStatusData(idle, GameTypeGame, game, ""))
}
// UpdateStreamingStatus is used to update the user's streaming status.
// If idle>0 then set status to idle.
// If game!="" then set game.
// If game!="" and url!="" then set the status type to streaming with the URL set.
// if otherwise, set status to active, and no game.
func (s *Session) UpdateStreamingStatus(idle int, game string, url string) (err error) {
gameType := GameTypeGame
if url != "" {
gameType = GameTypeStreaming
}
return s.UpdateStatusComplex(*newUpdateStatusData(idle, gameType, game, url))
}
// UpdateListeningStatus is used to set the user to "Listening to..."
// If game!="" then set to what user is listening to
// Else, set user to active and no game.
func (s *Session) UpdateListeningStatus(game string) (err error) {
return s.UpdateStatusComplex(*newUpdateStatusData(0, GameTypeListening, game, ""))
} }
// UpdateStatusComplex allows for sending the raw status update data untouched by discordgo. // UpdateStatusComplex allows for sending the raw status update data untouched by discordgo.
@ -371,14 +399,6 @@ func (s *Session) UpdateStatusComplex(usd UpdateStatusData) (err error) {
return return
} }
// UpdateStatus is used to update the user's status.
// If idle>0 then set status to idle.
// If game!="" then set game.
// if otherwise, set status to active, and no game.
func (s *Session) UpdateStatus(idle int, game string) (err error) {
return s.UpdateStreamingStatus(idle, game, "")
}
type requestGuildMembersData struct { type requestGuildMembersData struct {
GuildID string `json:"guild_id"` GuildID string `json:"guild_id"`
Query string `json:"query"` Query string `json:"query"`
@ -508,7 +528,7 @@ func (s *Session) onEvent(messageType int, message []byte) (*Event, error) {
s.Lock() s.Lock()
s.LastHeartbeatAck = time.Now().UTC() s.LastHeartbeatAck = time.Now().UTC()
s.Unlock() s.Unlock()
s.log(LogInformational, "got heartbeat ACK") s.log(LogDebug, "got heartbeat ACK")
return e, nil return e, nil
} }
@ -615,6 +635,30 @@ func (s *Session) ChannelVoiceJoin(gID, cID string, mute, deaf bool) (voice *Voi
return return
} }
// ChannelVoiceJoinManual initiates a voice session to a voice channel, but does not complete it.
//
// This should only be used when the VoiceServerUpdate will be intercepted and used elsewhere.
//
// gID : Guild ID of the channel to join.
// cID : Channel ID of the channel to join.
// mute : If true, you will be set to muted upon joining.
// deaf : If true, you will be set to deafened upon joining.
func (s *Session) ChannelVoiceJoinManual(gID, cID string, mute, deaf bool) (err error) {
s.log(LogInformational, "called")
// Send the request to Discord that we want to join the voice channel
data := voiceChannelJoinOp{4, voiceChannelJoinData{&gID, &cID, mute, deaf}}
s.wsMutex.Lock()
err = s.wsConn.WriteJSON(data)
s.wsMutex.Unlock()
if err != nil {
return
}
return
}
// onVoiceStateUpdate handles Voice State Update events on the data websocket. // onVoiceStateUpdate handles Voice State Update events on the data websocket.
func (s *Session) onVoiceStateUpdate(st *VoiceStateUpdate) { func (s *Session) onVoiceStateUpdate(st *VoiceStateUpdate) {
@ -732,11 +776,8 @@ func (s *Session) identify() error {
s.wsMutex.Lock() s.wsMutex.Lock()
err := s.wsConn.WriteJSON(op) err := s.wsConn.WriteJSON(op)
s.wsMutex.Unlock() s.wsMutex.Unlock()
if err != nil {
return err
}
return nil return err
} }
func (s *Session) reconnect() { func (s *Session) reconnect() {

View File

@ -3,11 +3,11 @@ sudo: false
matrix: matrix:
include: include:
- go: 1.4 - go: 1.7.x
- go: 1.5 - go: 1.8.x
- go: 1.6 - go: 1.9.x
- go: 1.7 - go: 1.10.x
- go: 1.8 - go: 1.11.x
- go: tip - go: tip
allow_failures: allow_failures:
- go: tip - go: tip

View File

@ -4,5 +4,6 @@
# Please keep the list sorted. # Please keep the list sorted.
Gary Burd <gary@beagledreams.com> Gary Burd <gary@beagledreams.com>
Google LLC (https://opensource.google.com/)
Joachim Bauch <mail@joachim-bauch.de> Joachim Bauch <mail@joachim-bauch.de>

View File

@ -5,15 +5,15 @@
package websocket package websocket
import ( import (
"bufio"
"bytes" "bytes"
"context"
"crypto/tls" "crypto/tls"
"encoding/base64"
"errors" "errors"
"io" "io"
"io/ioutil" "io/ioutil"
"net" "net"
"net/http" "net/http"
"net/http/httptrace"
"net/url" "net/url"
"strings" "strings"
"time" "time"
@ -53,6 +53,10 @@ type Dialer struct {
// NetDial is nil, net.Dial is used. // NetDial is nil, net.Dial is used.
NetDial func(network, addr string) (net.Conn, error) NetDial func(network, addr string) (net.Conn, error)
// NetDialContext specifies the dial function for creating TCP connections. If
// NetDialContext is nil, net.DialContext is used.
NetDialContext func(ctx context.Context, network, addr string) (net.Conn, error)
// Proxy specifies a function to return a proxy for a given // Proxy specifies a function to return a proxy for a given
// Request. If the function returns a non-nil error, the // Request. If the function returns a non-nil error, the
// request is aborted with the provided error. // request is aborted with the provided error.
@ -71,6 +75,17 @@ type Dialer struct {
// do not limit the size of the messages that can be sent or received. // do not limit the size of the messages that can be sent or received.
ReadBufferSize, WriteBufferSize int ReadBufferSize, WriteBufferSize int
// WriteBufferPool is a pool of buffers for write operations. If the value
// is not set, then write buffers are allocated to the connection for the
// lifetime of the connection.
//
// A pool is most useful when the application has a modest volume of writes
// across a large number of connections.
//
// Applications should use a single pool for each unique value of
// WriteBufferSize.
WriteBufferPool BufferPool
// Subprotocols specifies the client's requested subprotocols. // Subprotocols specifies the client's requested subprotocols.
Subprotocols []string Subprotocols []string
@ -86,52 +101,13 @@ type Dialer struct {
Jar http.CookieJar Jar http.CookieJar
} }
// Dial creates a new client connection by calling DialContext with a background context.
func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Response, error) {
return d.DialContext(context.Background(), urlStr, requestHeader)
}
var errMalformedURL = errors.New("malformed ws or wss URL") var errMalformedURL = errors.New("malformed ws or wss URL")
// parseURL parses the URL.
//
// This function is a replacement for the standard library url.Parse function.
// In Go 1.4 and earlier, url.Parse loses information from the path.
func parseURL(s string) (*url.URL, error) {
// From the RFC:
//
// ws-URI = "ws:" "//" host [ ":" port ] path [ "?" query ]
// wss-URI = "wss:" "//" host [ ":" port ] path [ "?" query ]
var u url.URL
switch {
case strings.HasPrefix(s, "ws://"):
u.Scheme = "ws"
s = s[len("ws://"):]
case strings.HasPrefix(s, "wss://"):
u.Scheme = "wss"
s = s[len("wss://"):]
default:
return nil, errMalformedURL
}
if i := strings.Index(s, "?"); i >= 0 {
u.RawQuery = s[i+1:]
s = s[:i]
}
if i := strings.Index(s, "/"); i >= 0 {
u.Opaque = s[i:]
s = s[:i]
} else {
u.Opaque = "/"
}
u.Host = s
if strings.Contains(u.Host, "@") {
// Don't bother parsing user information because user information is
// not allowed in websocket URIs.
return nil, errMalformedURL
}
return &u, nil
}
func hostPortNoPort(u *url.URL) (hostPort, hostNoPort string) { func hostPortNoPort(u *url.URL) (hostPort, hostNoPort string) {
hostPort = u.Host hostPort = u.Host
hostNoPort = u.Host hostNoPort = u.Host
@ -150,26 +126,29 @@ func hostPortNoPort(u *url.URL) (hostPort, hostNoPort string) {
return hostPort, hostNoPort return hostPort, hostNoPort
} }
// DefaultDialer is a dialer with all fields set to the default zero values. // DefaultDialer is a dialer with all fields set to the default values.
var DefaultDialer = &Dialer{ var DefaultDialer = &Dialer{
Proxy: http.ProxyFromEnvironment, Proxy: http.ProxyFromEnvironment,
HandshakeTimeout: 45 * time.Second,
} }
// Dial creates a new client connection. Use requestHeader to specify the // nilDialer is dialer to use when receiver is nil.
var nilDialer = *DefaultDialer
// DialContext creates a new client connection. Use requestHeader to specify the
// origin (Origin), subprotocols (Sec-WebSocket-Protocol) and cookies (Cookie). // origin (Origin), subprotocols (Sec-WebSocket-Protocol) and cookies (Cookie).
// Use the response.Header to get the selected subprotocol // Use the response.Header to get the selected subprotocol
// (Sec-WebSocket-Protocol) and cookies (Set-Cookie). // (Sec-WebSocket-Protocol) and cookies (Set-Cookie).
// //
// The context will be used in the request and in the Dialer
//
// If the WebSocket handshake fails, ErrBadHandshake is returned along with a // If the WebSocket handshake fails, ErrBadHandshake is returned along with a
// non-nil *http.Response so that callers can handle redirects, authentication, // non-nil *http.Response so that callers can handle redirects, authentication,
// etcetera. The response body may not contain the entire response and does not // etcetera. The response body may not contain the entire response and does not
// need to be closed by the application. // need to be closed by the application.
func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Response, error) { func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader http.Header) (*Conn, *http.Response, error) {
if d == nil { if d == nil {
d = &Dialer{ d = &nilDialer
Proxy: http.ProxyFromEnvironment,
}
} }
challengeKey, err := generateChallengeKey() challengeKey, err := generateChallengeKey()
@ -177,7 +156,7 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
return nil, nil, err return nil, nil, err
} }
u, err := parseURL(urlStr) u, err := url.Parse(urlStr)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@ -205,6 +184,7 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
Header: make(http.Header), Header: make(http.Header),
Host: u.Host, Host: u.Host,
} }
req = req.WithContext(ctx)
// Set the cookies present in the cookie jar of the dialer // Set the cookies present in the cookie jar of the dialer
if d.Jar != nil { if d.Jar != nil {
@ -237,45 +217,83 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
k == "Sec-Websocket-Extensions" || k == "Sec-Websocket-Extensions" ||
(k == "Sec-Websocket-Protocol" && len(d.Subprotocols) > 0): (k == "Sec-Websocket-Protocol" && len(d.Subprotocols) > 0):
return nil, nil, errors.New("websocket: duplicate header not allowed: " + k) return nil, nil, errors.New("websocket: duplicate header not allowed: " + k)
case k == "Sec-Websocket-Protocol":
req.Header["Sec-WebSocket-Protocol"] = vs
default: default:
req.Header[k] = vs req.Header[k] = vs
} }
} }
if d.EnableCompression { if d.EnableCompression {
req.Header.Set("Sec-Websocket-Extensions", "permessage-deflate; server_no_context_takeover; client_no_context_takeover") req.Header["Sec-WebSocket-Extensions"] = []string{"permessage-deflate; server_no_context_takeover; client_no_context_takeover"}
} }
hostPort, hostNoPort := hostPortNoPort(u) if d.HandshakeTimeout != 0 {
var cancel func()
ctx, cancel = context.WithTimeout(ctx, d.HandshakeTimeout)
defer cancel()
}
var proxyURL *url.URL // Get network dial function.
// Check wether the proxy method has been configured var netDial func(network, add string) (net.Conn, error)
if d.NetDialContext != nil {
netDial = func(network, addr string) (net.Conn, error) {
return d.NetDialContext(ctx, network, addr)
}
} else if d.NetDial != nil {
netDial = d.NetDial
} else {
netDialer := &net.Dialer{}
netDial = func(network, addr string) (net.Conn, error) {
return netDialer.DialContext(ctx, network, addr)
}
}
// If needed, wrap the dial function to set the connection deadline.
if deadline, ok := ctx.Deadline(); ok {
forwardDial := netDial
netDial = func(network, addr string) (net.Conn, error) {
c, err := forwardDial(network, addr)
if err != nil {
return nil, err
}
err = c.SetDeadline(deadline)
if err != nil {
c.Close()
return nil, err
}
return c, nil
}
}
// If needed, wrap the dial function to connect through a proxy.
if d.Proxy != nil { if d.Proxy != nil {
proxyURL, err = d.Proxy(req) proxyURL, err := d.Proxy(req)
}
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
var targetHostPort string
if proxyURL != nil { if proxyURL != nil {
targetHostPort, _ = hostPortNoPort(proxyURL) dialer, err := proxy_FromURL(proxyURL, netDialerFunc(netDial))
} else { if err != nil {
targetHostPort = hostPort return nil, nil, err
}
netDial = dialer.Dial
}
} }
var deadline time.Time hostPort, hostNoPort := hostPortNoPort(u)
if d.HandshakeTimeout != 0 { trace := httptrace.ContextClientTrace(ctx)
deadline = time.Now().Add(d.HandshakeTimeout) if trace != nil && trace.GetConn != nil {
trace.GetConn(hostPort)
} }
netDial := d.NetDial netConn, err := netDial("tcp", hostPort)
if netDial == nil { if trace != nil && trace.GotConn != nil {
netDialer := &net.Dialer{Deadline: deadline} trace.GotConn(httptrace.GotConnInfo{
netDial = netDialer.Dial Conn: netConn,
})
} }
netConn, err := netDial("tcp", targetHostPort)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@ -286,42 +304,6 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
} }
}() }()
if err := netConn.SetDeadline(deadline); err != nil {
return nil, nil, err
}
if proxyURL != nil {
connectHeader := make(http.Header)
if user := proxyURL.User; user != nil {
proxyUser := user.Username()
if proxyPassword, passwordSet := user.Password(); passwordSet {
credential := base64.StdEncoding.EncodeToString([]byte(proxyUser + ":" + proxyPassword))
connectHeader.Set("Proxy-Authorization", "Basic "+credential)
}
}
connectReq := &http.Request{
Method: "CONNECT",
URL: &url.URL{Opaque: hostPort},
Host: hostPort,
Header: connectHeader,
}
connectReq.Write(netConn)
// Read response.
// Okay to use and discard buffered reader here, because
// TLS server will not speak until spoken to.
br := bufio.NewReader(netConn)
resp, err := http.ReadResponse(br, connectReq)
if err != nil {
return nil, nil, err
}
if resp.StatusCode != 200 {
f := strings.SplitN(resp.Status, " ", 2)
return nil, nil, errors.New(f[1])
}
}
if u.Scheme == "https" { if u.Scheme == "https" {
cfg := cloneTLSConfig(d.TLSClientConfig) cfg := cloneTLSConfig(d.TLSClientConfig)
if cfg.ServerName == "" { if cfg.ServerName == "" {
@ -329,22 +311,31 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
} }
tlsConn := tls.Client(netConn, cfg) tlsConn := tls.Client(netConn, cfg)
netConn = tlsConn netConn = tlsConn
if err := tlsConn.Handshake(); err != nil {
return nil, nil, err var err error
if trace != nil {
err = doHandshakeWithTrace(trace, tlsConn, cfg)
} else {
err = doHandshake(tlsConn, cfg)
} }
if !cfg.InsecureSkipVerify {
if err := tlsConn.VerifyHostname(cfg.ServerName); err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
} }
}
conn := newConn(netConn, false, d.ReadBufferSize, d.WriteBufferSize) conn := newConn(netConn, false, d.ReadBufferSize, d.WriteBufferSize, d.WriteBufferPool, nil, nil)
if err := req.Write(netConn); err != nil { if err := req.Write(netConn); err != nil {
return nil, nil, err return nil, nil, err
} }
if trace != nil && trace.GotFirstResponseByte != nil {
if peek, err := conn.br.Peek(1); err == nil && len(peek) == 1 {
trace.GotFirstResponseByte()
}
}
resp, err := http.ReadResponse(conn.br, req) resp, err := http.ReadResponse(conn.br, req)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
@ -390,3 +381,15 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
netConn = nil // to avoid close in defer. netConn = nil // to avoid close in defer.
return conn, resp, nil return conn, resp, nil
} }
func doHandshake(tlsConn *tls.Conn, cfg *tls.Config) error {
if err := tlsConn.Handshake(); err != nil {
return err
}
if !cfg.InsecureSkipVerify {
if err := tlsConn.VerifyHostname(cfg.ServerName); err != nil {
return err
}
}
return nil
}

View File

@ -76,7 +76,7 @@ const (
// is UTF-8 encoded text. // is UTF-8 encoded text.
PingMessage = 9 PingMessage = 9
// PongMessage denotes a ping control message. The optional message payload // PongMessage denotes a pong control message. The optional message payload
// is UTF-8 encoded text. // is UTF-8 encoded text.
PongMessage = 10 PongMessage = 10
) )
@ -100,9 +100,8 @@ func (e *netError) Error() string { return e.msg }
func (e *netError) Temporary() bool { return e.temporary } func (e *netError) Temporary() bool { return e.temporary }
func (e *netError) Timeout() bool { return e.timeout } func (e *netError) Timeout() bool { return e.timeout }
// CloseError represents close frame. // CloseError represents a close message.
type CloseError struct { type CloseError struct {
// Code is defined in RFC 6455, section 11.7. // Code is defined in RFC 6455, section 11.7.
Code int Code int
@ -224,6 +223,20 @@ func isValidReceivedCloseCode(code int) bool {
return validReceivedCloseCodes[code] || (code >= 3000 && code <= 4999) return validReceivedCloseCodes[code] || (code >= 3000 && code <= 4999)
} }
// BufferPool represents a pool of buffers. The *sync.Pool type satisfies this
// interface. The type of the value stored in a pool is not specified.
type BufferPool interface {
// Get gets a value from the pool or returns nil if the pool is empty.
Get() interface{}
// Put adds a value to the pool.
Put(interface{})
}
// writePoolData is the type added to the write buffer pool. This wrapper is
// used to prevent applications from peeking at and depending on the values
// added to the pool.
type writePoolData struct{ buf []byte }
// The Conn type represents a WebSocket connection. // The Conn type represents a WebSocket connection.
type Conn struct { type Conn struct {
conn net.Conn conn net.Conn
@ -233,6 +246,8 @@ type Conn struct {
// Write fields // Write fields
mu chan bool // used as mutex to protect write to conn mu chan bool // used as mutex to protect write to conn
writeBuf []byte // frame is constructed in this buffer. writeBuf []byte // frame is constructed in this buffer.
writePool BufferPool
writeBufSize int
writeDeadline time.Time writeDeadline time.Time
writer io.WriteCloser // the current writer returned to the application writer io.WriteCloser // the current writer returned to the application
isWriting bool // for best-effort concurrent write detection isWriting bool // for best-effort concurrent write detection
@ -264,64 +279,29 @@ type Conn struct {
newDecompressionReader func(io.Reader) io.ReadCloser newDecompressionReader func(io.Reader) io.ReadCloser
} }
func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int) *Conn { func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int, writeBufferPool BufferPool, br *bufio.Reader, writeBuf []byte) *Conn {
return newConnBRW(conn, isServer, readBufferSize, writeBufferSize, nil)
}
type writeHook struct {
p []byte
}
func (wh *writeHook) Write(p []byte) (int, error) {
wh.p = p
return len(p), nil
}
func newConnBRW(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int, brw *bufio.ReadWriter) *Conn {
mu := make(chan bool, 1)
mu <- true
var br *bufio.Reader
if readBufferSize == 0 && brw != nil && brw.Reader != nil {
// Reuse the supplied bufio.Reader if the buffer has a useful size.
// This code assumes that peek on a reader returns
// bufio.Reader.buf[:0].
brw.Reader.Reset(conn)
if p, err := brw.Reader.Peek(0); err == nil && cap(p) >= 256 {
br = brw.Reader
}
}
if br == nil { if br == nil {
if readBufferSize == 0 { if readBufferSize == 0 {
readBufferSize = defaultReadBufferSize readBufferSize = defaultReadBufferSize
} } else if readBufferSize < maxControlFramePayloadSize {
if readBufferSize < maxControlFramePayloadSize { // must be large enough for control frame
readBufferSize = maxControlFramePayloadSize readBufferSize = maxControlFramePayloadSize
} }
br = bufio.NewReaderSize(conn, readBufferSize) br = bufio.NewReaderSize(conn, readBufferSize)
} }
var writeBuf []byte if writeBufferSize <= 0 {
if writeBufferSize == 0 && brw != nil && brw.Writer != nil {
// Use the bufio.Writer's buffer if the buffer has a useful size. This
// code assumes that bufio.Writer.buf[:1] is passed to the
// bufio.Writer's underlying writer.
var wh writeHook
brw.Writer.Reset(&wh)
brw.Writer.WriteByte(0)
brw.Flush()
if cap(wh.p) >= maxFrameHeaderSize+256 {
writeBuf = wh.p[:cap(wh.p)]
}
}
if writeBuf == nil {
if writeBufferSize == 0 {
writeBufferSize = defaultWriteBufferSize writeBufferSize = defaultWriteBufferSize
} }
writeBuf = make([]byte, writeBufferSize+maxFrameHeaderSize) writeBufferSize += maxFrameHeaderSize
if writeBuf == nil && writeBufferPool == nil {
writeBuf = make([]byte, writeBufferSize)
} }
mu := make(chan bool, 1)
mu <- true
c := &Conn{ c := &Conn{
isServer: isServer, isServer: isServer,
br: br, br: br,
@ -329,6 +309,8 @@ func newConnBRW(conn net.Conn, isServer bool, readBufferSize, writeBufferSize in
mu: mu, mu: mu,
readFinal: true, readFinal: true,
writeBuf: writeBuf, writeBuf: writeBuf,
writePool: writeBufferPool,
writeBufSize: writeBufferSize,
enableWriteCompression: true, enableWriteCompression: true,
compressionLevel: defaultCompressionLevel, compressionLevel: defaultCompressionLevel,
} }
@ -343,7 +325,8 @@ func (c *Conn) Subprotocol() string {
return c.subprotocol return c.subprotocol
} }
// Close closes the underlying network connection without sending or waiting for a close frame. // Close closes the underlying network connection without sending or waiting
// for a close message.
func (c *Conn) Close() error { func (c *Conn) Close() error {
return c.conn.Close() return c.conn.Close()
} }
@ -370,7 +353,16 @@ func (c *Conn) writeFatal(err error) error {
return err return err
} }
func (c *Conn) write(frameType int, deadline time.Time, bufs ...[]byte) error { func (c *Conn) read(n int) ([]byte, error) {
p, err := c.br.Peek(n)
if err == io.EOF {
err = errUnexpectedEOF
}
c.br.Discard(len(p))
return p, err
}
func (c *Conn) write(frameType int, deadline time.Time, buf0, buf1 []byte) error {
<-c.mu <-c.mu
defer func() { c.mu <- true }() defer func() { c.mu <- true }()
@ -382,15 +374,14 @@ func (c *Conn) write(frameType int, deadline time.Time, bufs ...[]byte) error {
} }
c.conn.SetWriteDeadline(deadline) c.conn.SetWriteDeadline(deadline)
for _, buf := range bufs { if len(buf1) == 0 {
if len(buf) > 0 { _, err = c.conn.Write(buf0)
_, err := c.conn.Write(buf) } else {
err = c.writeBufs(buf0, buf1)
}
if err != nil { if err != nil {
return c.writeFatal(err) return c.writeFatal(err)
} }
}
}
if frameType == CloseMessage { if frameType == CloseMessage {
c.writeFatal(ErrCloseSent) c.writeFatal(ErrCloseSent)
} }
@ -476,14 +467,29 @@ func (c *Conn) prepWrite(messageType int) error {
c.writeErrMu.Lock() c.writeErrMu.Lock()
err := c.writeErr err := c.writeErr
c.writeErrMu.Unlock() c.writeErrMu.Unlock()
if err != nil {
return err return err
} }
if c.writeBuf == nil {
wpd, ok := c.writePool.Get().(writePoolData)
if ok {
c.writeBuf = wpd.buf
} else {
c.writeBuf = make([]byte, c.writeBufSize)
}
}
return nil
}
// NextWriter returns a writer for the next message to send. The writer's Close // NextWriter returns a writer for the next message to send. The writer's Close
// method flushes the complete message to the network. // method flushes the complete message to the network.
// //
// There can be at most one open writer on a connection. NextWriter closes the // There can be at most one open writer on a connection. NextWriter closes the
// previous writer if the application has not already done so. // previous writer if the application has not already done so.
//
// All message types (TextMessage, BinaryMessage, CloseMessage, PingMessage and
// PongMessage) are supported.
func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) { func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) {
if err := c.prepWrite(messageType); err != nil { if err := c.prepWrite(messageType); err != nil {
return nil, err return nil, err
@ -599,6 +605,10 @@ func (w *messageWriter) flushFrame(final bool, extra []byte) error {
if final { if final {
c.writer = nil c.writer = nil
if c.writePool != nil {
c.writePool.Put(writePoolData{buf: c.writeBuf})
c.writeBuf = nil
}
return nil return nil
} }
@ -764,7 +774,6 @@ func (c *Conn) SetWriteDeadline(t time.Time) error {
// Read methods // Read methods
func (c *Conn) advanceFrame() (int, error) { func (c *Conn) advanceFrame() (int, error) {
// 1. Skip remainder of previous frame. // 1. Skip remainder of previous frame.
if c.readRemaining > 0 { if c.readRemaining > 0 {
@ -1033,7 +1042,7 @@ func (c *Conn) SetReadDeadline(t time.Time) error {
} }
// SetReadLimit sets the maximum size for a message read from the peer. If a // SetReadLimit sets the maximum size for a message read from the peer. If a
// message exceeds the limit, the connection sends a close frame to the peer // message exceeds the limit, the connection sends a close message to the peer
// and returns ErrReadLimit to the application. // and returns ErrReadLimit to the application.
func (c *Conn) SetReadLimit(limit int64) { func (c *Conn) SetReadLimit(limit int64) {
c.readLimit = limit c.readLimit = limit
@ -1046,24 +1055,22 @@ func (c *Conn) CloseHandler() func(code int, text string) error {
// SetCloseHandler sets the handler for close messages received from the peer. // SetCloseHandler sets the handler for close messages received from the peer.
// The code argument to h is the received close code or CloseNoStatusReceived // The code argument to h is the received close code or CloseNoStatusReceived
// if the close message is empty. The default close handler sends a close frame // if the close message is empty. The default close handler sends a close
// back to the peer. // message back to the peer.
// //
// The application must read the connection to process close messages as // The handler function is called from the NextReader, ReadMessage and message
// described in the section on Control Frames above. // reader Read methods. The application must read the connection to process
// close messages as described in the section on Control Messages above.
// //
// The connection read methods return a CloseError when a close frame is // The connection read methods return a CloseError when a close message is
// received. Most applications should handle close messages as part of their // received. Most applications should handle close messages as part of their
// normal error handling. Applications should only set a close handler when the // normal error handling. Applications should only set a close handler when the
// application must perform some action before sending a close frame back to // application must perform some action before sending a close message back to
// the peer. // the peer.
func (c *Conn) SetCloseHandler(h func(code int, text string) error) { func (c *Conn) SetCloseHandler(h func(code int, text string) error) {
if h == nil { if h == nil {
h = func(code int, text string) error { h = func(code int, text string) error {
message := []byte{} message := FormatCloseMessage(code, "")
if code != CloseNoStatusReceived {
message = FormatCloseMessage(code, "")
}
c.WriteControl(CloseMessage, message, time.Now().Add(writeWait)) c.WriteControl(CloseMessage, message, time.Now().Add(writeWait))
return nil return nil
} }
@ -1077,11 +1084,12 @@ func (c *Conn) PingHandler() func(appData string) error {
} }
// SetPingHandler sets the handler for ping messages received from the peer. // SetPingHandler sets the handler for ping messages received from the peer.
// The appData argument to h is the PING frame application data. The default // The appData argument to h is the PING message application data. The default
// ping handler sends a pong to the peer. // ping handler sends a pong to the peer.
// //
// The application must read the connection to process ping messages as // The handler function is called from the NextReader, ReadMessage and message
// described in the section on Control Frames above. // reader Read methods. The application must read the connection to process
// ping messages as described in the section on Control Messages above.
func (c *Conn) SetPingHandler(h func(appData string) error) { func (c *Conn) SetPingHandler(h func(appData string) error) {
if h == nil { if h == nil {
h = func(message string) error { h = func(message string) error {
@ -1103,11 +1111,12 @@ func (c *Conn) PongHandler() func(appData string) error {
} }
// SetPongHandler sets the handler for pong messages received from the peer. // SetPongHandler sets the handler for pong messages received from the peer.
// The appData argument to h is the PONG frame application data. The default // The appData argument to h is the PONG message application data. The default
// pong handler does nothing. // pong handler does nothing.
// //
// The application must read the connection to process ping messages as // The handler function is called from the NextReader, ReadMessage and message
// described in the section on Control Frames above. // reader Read methods. The application must read the connection to process
// pong messages as described in the section on Control Messages above.
func (c *Conn) SetPongHandler(h func(appData string) error) { func (c *Conn) SetPongHandler(h func(appData string) error) {
if h == nil { if h == nil {
h = func(string) error { return nil } h = func(string) error { return nil }
@ -1141,7 +1150,14 @@ func (c *Conn) SetCompressionLevel(level int) error {
} }
// FormatCloseMessage formats closeCode and text as a WebSocket close message. // FormatCloseMessage formats closeCode and text as a WebSocket close message.
// An empty message is returned for code CloseNoStatusReceived.
func FormatCloseMessage(closeCode int, text string) []byte { func FormatCloseMessage(closeCode int, text string) []byte {
if closeCode == CloseNoStatusReceived {
// Return empty message because it's illegal to send
// CloseNoStatusReceived. Return non-nil value in case application
// checks for nil.
return []byte{}
}
buf := make([]byte, 2+len(text)) buf := make([]byte, 2+len(text))
binary.BigEndian.PutUint16(buf, uint16(closeCode)) binary.BigEndian.PutUint16(buf, uint16(closeCode))
copy(buf[2:], text) copy(buf[2:], text)

View File

@ -1,21 +0,0 @@
// Copyright 2016 The Gorilla WebSocket Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build !go1.5
package websocket
import "io"
func (c *Conn) read(n int) ([]byte, error) {
p, err := c.br.Peek(n)
if err == io.EOF {
err = errUnexpectedEOF
}
if len(p) > 0 {
// advance over the bytes just read
io.ReadFull(c.br, p)
}
return p, err
}

View File

@ -2,17 +2,14 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// +build go1.5 // +build go1.8
package websocket package websocket
import "io" import "net"
func (c *Conn) read(n int) ([]byte, error) { func (c *Conn) writeBufs(bufs ...[]byte) error {
p, err := c.br.Peek(n) b := net.Buffers(bufs)
if err == io.EOF { _, err := b.WriteTo(c.conn)
err = errUnexpectedEOF return err
}
c.br.Discard(len(p))
return p, err
} }

View File

@ -0,0 +1,18 @@
// Copyright 2016 The Gorilla WebSocket Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build !go1.8
package websocket
func (c *Conn) writeBufs(bufs ...[]byte) error {
for _, buf := range bufs {
if len(buf) > 0 {
if _, err := c.conn.Write(buf); err != nil {
return err
}
}
}
return nil
}

View File

@ -6,9 +6,8 @@
// //
// Overview // Overview
// //
// The Conn type represents a WebSocket connection. A server application uses // The Conn type represents a WebSocket connection. A server application calls
// the Upgrade function from an Upgrader object with a HTTP request handler // the Upgrader.Upgrade method from an HTTP request handler to get a *Conn:
// to get a pointer to a Conn:
// //
// var upgrader = websocket.Upgrader{ // var upgrader = websocket.Upgrader{
// ReadBufferSize: 1024, // ReadBufferSize: 1024,
@ -31,10 +30,12 @@
// for { // for {
// messageType, p, err := conn.ReadMessage() // messageType, p, err := conn.ReadMessage()
// if err != nil { // if err != nil {
// log.Println(err)
// return // return
// } // }
// if err = conn.WriteMessage(messageType, p); err != nil { // if err := conn.WriteMessage(messageType, p); err != nil {
// return err // log.Println(err)
// return
// } // }
// } // }
// //
@ -85,20 +86,26 @@
// and pong. Call the connection WriteControl, WriteMessage or NextWriter // and pong. Call the connection WriteControl, WriteMessage or NextWriter
// methods to send a control message to the peer. // methods to send a control message to the peer.
// //
// Connections handle received close messages by sending a close message to the // Connections handle received close messages by calling the handler function
// peer and returning a *CloseError from the the NextReader, ReadMessage or the // set with the SetCloseHandler method and by returning a *CloseError from the
// message Read method. // NextReader, ReadMessage or the message Read method. The default close
// handler sends a close message to the peer.
// //
// Connections handle received ping and pong messages by invoking callback // Connections handle received ping messages by calling the handler function
// functions set with SetPingHandler and SetPongHandler methods. The callback // set with the SetPingHandler method. The default ping handler sends a pong
// functions are called from the NextReader, ReadMessage and the message Read // message to the peer.
// methods.
// //
// The default ping handler sends a pong to the peer. The application's reading // Connections handle received pong messages by calling the handler function
// goroutine can block for a short time while the handler writes the pong data // set with the SetPongHandler method. The default pong handler does nothing.
// to the connection. // If an application sends ping messages, then the application should set a
// pong handler to receive the corresponding pong.
// //
// The application must read the connection to process ping, pong and close // The control message handler functions are called from the NextReader,
// ReadMessage and message reader Read methods. The default close and ping
// handlers can block these methods for a short time when the handler writes to
// the connection.
//
// The application must read the connection to process close, ping and pong
// messages sent from the peer. If the application is not otherwise interested // messages sent from the peer. If the application is not otherwise interested
// in messages from the peer, then the application should start a goroutine to // in messages from the peer, then the application should start a goroutine to
// read and discard messages from the peer. A simple example is: // read and discard messages from the peer. A simple example is:
@ -137,19 +144,12 @@
// method fails the WebSocket handshake with HTTP status 403. // method fails the WebSocket handshake with HTTP status 403.
// //
// If the CheckOrigin field is nil, then the Upgrader uses a safe default: fail // If the CheckOrigin field is nil, then the Upgrader uses a safe default: fail
// the handshake if the Origin request header is present and not equal to the // the handshake if the Origin request header is present and the Origin host is
// Host request header. // not equal to the Host request header.
// //
// An application can allow connections from any origin by specifying a // The deprecated package-level Upgrade function does not perform origin
// function that always returns true: // checking. The application is responsible for checking the Origin header
// // before calling the Upgrade function.
// var upgrader = websocket.Upgrader{
// CheckOrigin: func(r *http.Request) bool { return true },
// }
//
// The deprecated Upgrade function does not enforce an origin policy. It's the
// application's responsibility to check the Origin header before calling
// Upgrade.
// //
// Compression EXPERIMENTAL // Compression EXPERIMENTAL
// //

View File

@ -9,12 +9,14 @@ import (
"io" "io"
) )
// WriteJSON is deprecated, use c.WriteJSON instead. // WriteJSON writes the JSON encoding of v as a message.
//
// Deprecated: Use c.WriteJSON instead.
func WriteJSON(c *Conn, v interface{}) error { func WriteJSON(c *Conn, v interface{}) error {
return c.WriteJSON(v) return c.WriteJSON(v)
} }
// WriteJSON writes the JSON encoding of v to the connection. // WriteJSON writes the JSON encoding of v as a message.
// //
// See the documentation for encoding/json Marshal for details about the // See the documentation for encoding/json Marshal for details about the
// conversion of Go values to JSON. // conversion of Go values to JSON.
@ -31,7 +33,10 @@ func (c *Conn) WriteJSON(v interface{}) error {
return err2 return err2
} }
// ReadJSON is deprecated, use c.ReadJSON instead. // ReadJSON reads the next JSON-encoded message from the connection and stores
// it in the value pointed to by v.
//
// Deprecated: Use c.ReadJSON instead.
func ReadJSON(c *Conn, v interface{}) error { func ReadJSON(c *Conn, v interface{}) error {
return c.ReadJSON(v) return c.ReadJSON(v)
} }

View File

@ -11,7 +11,6 @@ import "unsafe"
const wordSize = int(unsafe.Sizeof(uintptr(0))) const wordSize = int(unsafe.Sizeof(uintptr(0)))
func maskBytes(key [4]byte, pos int, b []byte) int { func maskBytes(key [4]byte, pos int, b []byte) int {
// Mask one byte at a time for small buffers. // Mask one byte at a time for small buffers.
if len(b) < 2*wordSize { if len(b) < 2*wordSize {
for i := range b { for i := range b {

View File

@ -19,7 +19,6 @@ import (
type PreparedMessage struct { type PreparedMessage struct {
messageType int messageType int
data []byte data []byte
err error
mu sync.Mutex mu sync.Mutex
frames map[prepareKey]*preparedFrame frames map[prepareKey]*preparedFrame
} }

77
vendor/github.com/gorilla/websocket/proxy.go generated vendored Normal file
View File

@ -0,0 +1,77 @@
// Copyright 2017 The Gorilla WebSocket Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package websocket
import (
"bufio"
"encoding/base64"
"errors"
"net"
"net/http"
"net/url"
"strings"
)
type netDialerFunc func(network, addr string) (net.Conn, error)
func (fn netDialerFunc) Dial(network, addr string) (net.Conn, error) {
return fn(network, addr)
}
func init() {
proxy_RegisterDialerType("http", func(proxyURL *url.URL, forwardDialer proxy_Dialer) (proxy_Dialer, error) {
return &httpProxyDialer{proxyURL: proxyURL, fowardDial: forwardDialer.Dial}, nil
})
}
type httpProxyDialer struct {
proxyURL *url.URL
fowardDial func(network, addr string) (net.Conn, error)
}
func (hpd *httpProxyDialer) Dial(network string, addr string) (net.Conn, error) {
hostPort, _ := hostPortNoPort(hpd.proxyURL)
conn, err := hpd.fowardDial(network, hostPort)
if err != nil {
return nil, err
}
connectHeader := make(http.Header)
if user := hpd.proxyURL.User; user != nil {
proxyUser := user.Username()
if proxyPassword, passwordSet := user.Password(); passwordSet {
credential := base64.StdEncoding.EncodeToString([]byte(proxyUser + ":" + proxyPassword))
connectHeader.Set("Proxy-Authorization", "Basic "+credential)
}
}
connectReq := &http.Request{
Method: "CONNECT",
URL: &url.URL{Opaque: addr},
Host: addr,
Header: connectHeader,
}
if err := connectReq.Write(conn); err != nil {
conn.Close()
return nil, err
}
// Read response. It's OK to use and discard buffered reader here becaue
// the remote server does not speak until spoken to.
br := bufio.NewReader(conn)
resp, err := http.ReadResponse(br, connectReq)
if err != nil {
conn.Close()
return nil, err
}
if resp.StatusCode != 200 {
conn.Close()
f := strings.SplitN(resp.Status, " ", 2)
return nil, errors.New(f[1])
}
return conn, nil
}

View File

@ -7,7 +7,7 @@ package websocket
import ( import (
"bufio" "bufio"
"errors" "errors"
"net" "io"
"net/http" "net/http"
"net/url" "net/url"
"strings" "strings"
@ -33,10 +33,23 @@ type Upgrader struct {
// or received. // or received.
ReadBufferSize, WriteBufferSize int ReadBufferSize, WriteBufferSize int
// WriteBufferPool is a pool of buffers for write operations. If the value
// is not set, then write buffers are allocated to the connection for the
// lifetime of the connection.
//
// A pool is most useful when the application has a modest volume of writes
// across a large number of connections.
//
// Applications should use a single pool for each unique value of
// WriteBufferSize.
WriteBufferPool BufferPool
// Subprotocols specifies the server's supported protocols in order of // Subprotocols specifies the server's supported protocols in order of
// preference. If this field is set, then the Upgrade method negotiates a // preference. If this field is not nil, then the Upgrade method negotiates a
// subprotocol by selecting the first match in this list with a protocol // subprotocol by selecting the first match in this list with a protocol
// requested by the client. // requested by the client. If there's no match, then no protocol is
// negotiated (the Sec-Websocket-Protocol header is not included in the
// handshake response).
Subprotocols []string Subprotocols []string
// Error specifies the function for generating HTTP error responses. If Error // Error specifies the function for generating HTTP error responses. If Error
@ -44,8 +57,12 @@ type Upgrader struct {
Error func(w http.ResponseWriter, r *http.Request, status int, reason error) Error func(w http.ResponseWriter, r *http.Request, status int, reason error)
// CheckOrigin returns true if the request Origin header is acceptable. If // CheckOrigin returns true if the request Origin header is acceptable. If
// CheckOrigin is nil, the host in the Origin header must not be set or // CheckOrigin is nil, then a safe default is used: return false if the
// must match the host of the request. // Origin request header is present and the origin host is not equal to
// request Host header.
//
// A CheckOrigin function should carefully validate the request origin to
// prevent cross-site request forgery.
CheckOrigin func(r *http.Request) bool CheckOrigin func(r *http.Request) bool
// EnableCompression specify if the server should attempt to negotiate per // EnableCompression specify if the server should attempt to negotiate per
@ -76,7 +93,7 @@ func checkSameOrigin(r *http.Request) bool {
if err != nil { if err != nil {
return false return false
} }
return u.Host == r.Host return equalASCIIFold(u.Host, r.Host)
} }
func (u *Upgrader) selectSubprotocol(r *http.Request, responseHeader http.Header) string { func (u *Upgrader) selectSubprotocol(r *http.Request, responseHeader http.Header) string {
@ -99,42 +116,44 @@ func (u *Upgrader) selectSubprotocol(r *http.Request, responseHeader http.Header
// //
// The responseHeader is included in the response to the client's upgrade // The responseHeader is included in the response to the client's upgrade
// request. Use the responseHeader to specify cookies (Set-Cookie) and the // request. Use the responseHeader to specify cookies (Set-Cookie) and the
// application negotiated subprotocol (Sec-Websocket-Protocol). // application negotiated subprotocol (Sec-WebSocket-Protocol).
// //
// If the upgrade fails, then Upgrade replies to the client with an HTTP error // If the upgrade fails, then Upgrade replies to the client with an HTTP error
// response. // response.
func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header) (*Conn, error) { func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header) (*Conn, error) {
if r.Method != "GET" { const badHandshake = "websocket: the client is not using the websocket protocol: "
return u.returnError(w, r, http.StatusMethodNotAllowed, "websocket: not a websocket handshake: request method is not GET")
}
if _, ok := responseHeader["Sec-Websocket-Extensions"]; ok {
return u.returnError(w, r, http.StatusInternalServerError, "websocket: application specific 'Sec-Websocket-Extensions' headers are unsupported")
}
if !tokenListContainsValue(r.Header, "Connection", "upgrade") { if !tokenListContainsValue(r.Header, "Connection", "upgrade") {
return u.returnError(w, r, http.StatusBadRequest, "websocket: not a websocket handshake: 'upgrade' token not found in 'Connection' header") return u.returnError(w, r, http.StatusBadRequest, badHandshake+"'upgrade' token not found in 'Connection' header")
} }
if !tokenListContainsValue(r.Header, "Upgrade", "websocket") { if !tokenListContainsValue(r.Header, "Upgrade", "websocket") {
return u.returnError(w, r, http.StatusBadRequest, "websocket: not a websocket handshake: 'websocket' token not found in 'Upgrade' header") return u.returnError(w, r, http.StatusBadRequest, badHandshake+"'websocket' token not found in 'Upgrade' header")
}
if r.Method != "GET" {
return u.returnError(w, r, http.StatusMethodNotAllowed, badHandshake+"request method is not GET")
} }
if !tokenListContainsValue(r.Header, "Sec-Websocket-Version", "13") { if !tokenListContainsValue(r.Header, "Sec-Websocket-Version", "13") {
return u.returnError(w, r, http.StatusBadRequest, "websocket: unsupported version: 13 not found in 'Sec-Websocket-Version' header") return u.returnError(w, r, http.StatusBadRequest, "websocket: unsupported version: 13 not found in 'Sec-Websocket-Version' header")
} }
if _, ok := responseHeader["Sec-Websocket-Extensions"]; ok {
return u.returnError(w, r, http.StatusInternalServerError, "websocket: application specific 'Sec-WebSocket-Extensions' headers are unsupported")
}
checkOrigin := u.CheckOrigin checkOrigin := u.CheckOrigin
if checkOrigin == nil { if checkOrigin == nil {
checkOrigin = checkSameOrigin checkOrigin = checkSameOrigin
} }
if !checkOrigin(r) { if !checkOrigin(r) {
return u.returnError(w, r, http.StatusForbidden, "websocket: 'Origin' header value not allowed") return u.returnError(w, r, http.StatusForbidden, "websocket: request origin not allowed by Upgrader.CheckOrigin")
} }
challengeKey := r.Header.Get("Sec-Websocket-Key") challengeKey := r.Header.Get("Sec-Websocket-Key")
if challengeKey == "" { if challengeKey == "" {
return u.returnError(w, r, http.StatusBadRequest, "websocket: not a websocket handshake: `Sec-Websocket-Key' header is missing or blank") return u.returnError(w, r, http.StatusBadRequest, "websocket: not a websocket handshake: `Sec-WebSocket-Key' header is missing or blank")
} }
subprotocol := u.selectSubprotocol(r, responseHeader) subprotocol := u.selectSubprotocol(r, responseHeader)
@ -151,17 +170,12 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
} }
} }
var (
netConn net.Conn
err error
)
h, ok := w.(http.Hijacker) h, ok := w.(http.Hijacker)
if !ok { if !ok {
return u.returnError(w, r, http.StatusInternalServerError, "websocket: response does not implement http.Hijacker") return u.returnError(w, r, http.StatusInternalServerError, "websocket: response does not implement http.Hijacker")
} }
var brw *bufio.ReadWriter var brw *bufio.ReadWriter
netConn, brw, err = h.Hijack() netConn, brw, err := h.Hijack()
if err != nil { if err != nil {
return u.returnError(w, r, http.StatusInternalServerError, err.Error()) return u.returnError(w, r, http.StatusInternalServerError, err.Error())
} }
@ -171,7 +185,21 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
return nil, errors.New("websocket: client sent data before handshake is complete") return nil, errors.New("websocket: client sent data before handshake is complete")
} }
c := newConnBRW(netConn, true, u.ReadBufferSize, u.WriteBufferSize, brw) var br *bufio.Reader
if u.ReadBufferSize == 0 && bufioReaderSize(netConn, brw.Reader) > 256 {
// Reuse hijacked buffered reader as connection reader.
br = brw.Reader
}
buf := bufioWriterBuffer(netConn, brw.Writer)
var writeBuf []byte
if u.WriteBufferPool == nil && u.WriteBufferSize == 0 && len(buf) >= maxFrameHeaderSize+256 {
// Reuse hijacked write buffer as connection buffer.
writeBuf = buf
}
c := newConn(netConn, true, u.ReadBufferSize, u.WriteBufferSize, u.WriteBufferPool, br, writeBuf)
c.subprotocol = subprotocol c.subprotocol = subprotocol
if compress { if compress {
@ -179,17 +207,23 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
c.newDecompressionReader = decompressNoContextTakeover c.newDecompressionReader = decompressNoContextTakeover
} }
p := c.writeBuf[:0] // Use larger of hijacked buffer and connection write buffer for header.
p := buf
if len(c.writeBuf) > len(p) {
p = c.writeBuf
}
p = p[:0]
p = append(p, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: "...) p = append(p, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: "...)
p = append(p, computeAcceptKey(challengeKey)...) p = append(p, computeAcceptKey(challengeKey)...)
p = append(p, "\r\n"...) p = append(p, "\r\n"...)
if c.subprotocol != "" { if c.subprotocol != "" {
p = append(p, "Sec-Websocket-Protocol: "...) p = append(p, "Sec-WebSocket-Protocol: "...)
p = append(p, c.subprotocol...) p = append(p, c.subprotocol...)
p = append(p, "\r\n"...) p = append(p, "\r\n"...)
} }
if compress { if compress {
p = append(p, "Sec-Websocket-Extensions: permessage-deflate; server_no_context_takeover; client_no_context_takeover\r\n"...) p = append(p, "Sec-WebSocket-Extensions: permessage-deflate; server_no_context_takeover; client_no_context_takeover\r\n"...)
} }
for k, vs := range responseHeader { for k, vs := range responseHeader {
if k == "Sec-Websocket-Protocol" { if k == "Sec-Websocket-Protocol" {
@ -230,13 +264,14 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
// Upgrade upgrades the HTTP server connection to the WebSocket protocol. // Upgrade upgrades the HTTP server connection to the WebSocket protocol.
// //
// This function is deprecated, use websocket.Upgrader instead. // Deprecated: Use websocket.Upgrader instead.
// //
// The application is responsible for checking the request origin before // Upgrade does not perform origin checking. The application is responsible for
// calling Upgrade. An example implementation of the same origin policy is: // checking the Origin header before calling Upgrade. An example implementation
// of the same origin policy check is:
// //
// if req.Header.Get("Origin") != "http://"+req.Host { // if req.Header.Get("Origin") != "http://"+req.Host {
// http.Error(w, "Origin not allowed", 403) // http.Error(w, "Origin not allowed", http.StatusForbidden)
// return // return
// } // }
// //
@ -289,3 +324,40 @@ func IsWebSocketUpgrade(r *http.Request) bool {
return tokenListContainsValue(r.Header, "Connection", "upgrade") && return tokenListContainsValue(r.Header, "Connection", "upgrade") &&
tokenListContainsValue(r.Header, "Upgrade", "websocket") tokenListContainsValue(r.Header, "Upgrade", "websocket")
} }
// bufioReaderSize size returns the size of a bufio.Reader.
func bufioReaderSize(originalReader io.Reader, br *bufio.Reader) int {
// This code assumes that peek on a reset reader returns
// bufio.Reader.buf[:0].
// TODO: Use bufio.Reader.Size() after Go 1.10
br.Reset(originalReader)
if p, err := br.Peek(0); err == nil {
return cap(p)
}
return 0
}
// writeHook is an io.Writer that records the last slice passed to it vio
// io.Writer.Write.
type writeHook struct {
p []byte
}
func (wh *writeHook) Write(p []byte) (int, error) {
wh.p = p
return len(p), nil
}
// bufioWriterBuffer grabs the buffer from a bufio.Writer.
func bufioWriterBuffer(originalWriter io.Writer, bw *bufio.Writer) []byte {
// This code assumes that bufio.Writer.buf[:1] is passed to the
// bufio.Writer's underlying writer.
var wh writeHook
bw.Reset(&wh)
bw.WriteByte(0)
bw.Flush()
bw.Reset(originalWriter)
return wh.p[:cap(wh.p)]
}

19
vendor/github.com/gorilla/websocket/trace.go generated vendored Normal file
View File

@ -0,0 +1,19 @@
// +build go1.8
package websocket
import (
"crypto/tls"
"net/http/httptrace"
)
func doHandshakeWithTrace(trace *httptrace.ClientTrace, tlsConn *tls.Conn, cfg *tls.Config) error {
if trace.TLSHandshakeStart != nil {
trace.TLSHandshakeStart()
}
err := doHandshake(tlsConn, cfg)
if trace.TLSHandshakeDone != nil {
trace.TLSHandshakeDone(tlsConn.ConnectionState(), err)
}
return err
}

12
vendor/github.com/gorilla/websocket/trace_17.go generated vendored Normal file
View File

@ -0,0 +1,12 @@
// +build !go1.8
package websocket
import (
"crypto/tls"
"net/http/httptrace"
)
func doHandshakeWithTrace(trace *httptrace.ClientTrace, tlsConn *tls.Conn, cfg *tls.Config) error {
return doHandshake(tlsConn, cfg)
}

View File

@ -11,6 +11,7 @@ import (
"io" "io"
"net/http" "net/http"
"strings" "strings"
"unicode/utf8"
) )
var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11") var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11")
@ -111,14 +112,14 @@ func nextTokenOrQuoted(s string) (value string, rest string) {
case escape: case escape:
escape = false escape = false
p[j] = b p[j] = b
j += 1 j++
case b == '\\': case b == '\\':
escape = true escape = true
case b == '"': case b == '"':
return string(p[:j]), s[i+1:] return string(p[:j]), s[i+1:]
default: default:
p[j] = b p[j] = b
j += 1 j++
} }
} }
return "", "" return "", ""
@ -127,8 +128,31 @@ func nextTokenOrQuoted(s string) (value string, rest string) {
return "", "" return "", ""
} }
// equalASCIIFold returns true if s is equal to t with ASCII case folding.
func equalASCIIFold(s, t string) bool {
for s != "" && t != "" {
sr, size := utf8.DecodeRuneInString(s)
s = s[size:]
tr, size := utf8.DecodeRuneInString(t)
t = t[size:]
if sr == tr {
continue
}
if 'A' <= sr && sr <= 'Z' {
sr = sr + 'a' - 'A'
}
if 'A' <= tr && tr <= 'Z' {
tr = tr + 'a' - 'A'
}
if sr != tr {
return false
}
}
return s == t
}
// tokenListContainsValue returns true if the 1#token header with the given // tokenListContainsValue returns true if the 1#token header with the given
// name contains token. // name contains a token equal to value with ASCII case folding.
func tokenListContainsValue(header http.Header, name string, value string) bool { func tokenListContainsValue(header http.Header, name string, value string) bool {
headers: headers:
for _, s := range header[name] { for _, s := range header[name] {
@ -142,7 +166,7 @@ headers:
if s != "" && s[0] != ',' { if s != "" && s[0] != ',' {
continue headers continue headers
} }
if strings.EqualFold(t, value) { if equalASCIIFold(t, value) {
return true return true
} }
if s == "" { if s == "" {
@ -154,9 +178,8 @@ headers:
return false return false
} }
// parseExtensiosn parses WebSocket extensions from a header. // parseExtensions parses WebSocket extensions from a header.
func parseExtensions(header http.Header) []map[string]string { func parseExtensions(header http.Header) []map[string]string {
// From RFC 6455: // From RFC 6455:
// //
// Sec-WebSocket-Extensions = extension-list // Sec-WebSocket-Extensions = extension-list

473
vendor/github.com/gorilla/websocket/x_net_proxy.go generated vendored Normal file
View File

@ -0,0 +1,473 @@
// Code generated by golang.org/x/tools/cmd/bundle. DO NOT EDIT.
//go:generate bundle -o x_net_proxy.go golang.org/x/net/proxy
// Package proxy provides support for a variety of protocols to proxy network
// data.
//
package websocket
import (
"errors"
"io"
"net"
"net/url"
"os"
"strconv"
"strings"
"sync"
)
type proxy_direct struct{}
// Direct is a direct proxy: one that makes network connections directly.
var proxy_Direct = proxy_direct{}
func (proxy_direct) Dial(network, addr string) (net.Conn, error) {
return net.Dial(network, addr)
}
// A PerHost directs connections to a default Dialer unless the host name
// requested matches one of a number of exceptions.
type proxy_PerHost struct {
def, bypass proxy_Dialer
bypassNetworks []*net.IPNet
bypassIPs []net.IP
bypassZones []string
bypassHosts []string
}
// NewPerHost returns a PerHost Dialer that directs connections to either
// defaultDialer or bypass, depending on whether the connection matches one of
// the configured rules.
func proxy_NewPerHost(defaultDialer, bypass proxy_Dialer) *proxy_PerHost {
return &proxy_PerHost{
def: defaultDialer,
bypass: bypass,
}
}
// Dial connects to the address addr on the given network through either
// defaultDialer or bypass.
func (p *proxy_PerHost) Dial(network, addr string) (c net.Conn, err error) {
host, _, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
}
return p.dialerForRequest(host).Dial(network, addr)
}
func (p *proxy_PerHost) dialerForRequest(host string) proxy_Dialer {
if ip := net.ParseIP(host); ip != nil {
for _, net := range p.bypassNetworks {
if net.Contains(ip) {
return p.bypass
}
}
for _, bypassIP := range p.bypassIPs {
if bypassIP.Equal(ip) {
return p.bypass
}
}
return p.def
}
for _, zone := range p.bypassZones {
if strings.HasSuffix(host, zone) {
return p.bypass
}
if host == zone[1:] {
// For a zone ".example.com", we match "example.com"
// too.
return p.bypass
}
}
for _, bypassHost := range p.bypassHosts {
if bypassHost == host {
return p.bypass
}
}
return p.def
}
// AddFromString parses a string that contains comma-separated values
// specifying hosts that should use the bypass proxy. Each value is either an
// IP address, a CIDR range, a zone (*.example.com) or a host name
// (localhost). A best effort is made to parse the string and errors are
// ignored.
func (p *proxy_PerHost) AddFromString(s string) {
hosts := strings.Split(s, ",")
for _, host := range hosts {
host = strings.TrimSpace(host)
if len(host) == 0 {
continue
}
if strings.Contains(host, "/") {
// We assume that it's a CIDR address like 127.0.0.0/8
if _, net, err := net.ParseCIDR(host); err == nil {
p.AddNetwork(net)
}
continue
}
if ip := net.ParseIP(host); ip != nil {
p.AddIP(ip)
continue
}
if strings.HasPrefix(host, "*.") {
p.AddZone(host[1:])
continue
}
p.AddHost(host)
}
}
// AddIP specifies an IP address that will use the bypass proxy. Note that
// this will only take effect if a literal IP address is dialed. A connection
// to a named host will never match an IP.
func (p *proxy_PerHost) AddIP(ip net.IP) {
p.bypassIPs = append(p.bypassIPs, ip)
}
// AddNetwork specifies an IP range that will use the bypass proxy. Note that
// this will only take effect if a literal IP address is dialed. A connection
// to a named host will never match.
func (p *proxy_PerHost) AddNetwork(net *net.IPNet) {
p.bypassNetworks = append(p.bypassNetworks, net)
}
// AddZone specifies a DNS suffix that will use the bypass proxy. A zone of
// "example.com" matches "example.com" and all of its subdomains.
func (p *proxy_PerHost) AddZone(zone string) {
if strings.HasSuffix(zone, ".") {
zone = zone[:len(zone)-1]
}
if !strings.HasPrefix(zone, ".") {
zone = "." + zone
}
p.bypassZones = append(p.bypassZones, zone)
}
// AddHost specifies a host name that will use the bypass proxy.
func (p *proxy_PerHost) AddHost(host string) {
if strings.HasSuffix(host, ".") {
host = host[:len(host)-1]
}
p.bypassHosts = append(p.bypassHosts, host)
}
// A Dialer is a means to establish a connection.
type proxy_Dialer interface {
// Dial connects to the given address via the proxy.
Dial(network, addr string) (c net.Conn, err error)
}
// Auth contains authentication parameters that specific Dialers may require.
type proxy_Auth struct {
User, Password string
}
// FromEnvironment returns the dialer specified by the proxy related variables in
// the environment.
func proxy_FromEnvironment() proxy_Dialer {
allProxy := proxy_allProxyEnv.Get()
if len(allProxy) == 0 {
return proxy_Direct
}
proxyURL, err := url.Parse(allProxy)
if err != nil {
return proxy_Direct
}
proxy, err := proxy_FromURL(proxyURL, proxy_Direct)
if err != nil {
return proxy_Direct
}
noProxy := proxy_noProxyEnv.Get()
if len(noProxy) == 0 {
return proxy
}
perHost := proxy_NewPerHost(proxy, proxy_Direct)
perHost.AddFromString(noProxy)
return perHost
}
// proxySchemes is a map from URL schemes to a function that creates a Dialer
// from a URL with such a scheme.
var proxy_proxySchemes map[string]func(*url.URL, proxy_Dialer) (proxy_Dialer, error)
// RegisterDialerType takes a URL scheme and a function to generate Dialers from
// a URL with that scheme and a forwarding Dialer. Registered schemes are used
// by FromURL.
func proxy_RegisterDialerType(scheme string, f func(*url.URL, proxy_Dialer) (proxy_Dialer, error)) {
if proxy_proxySchemes == nil {
proxy_proxySchemes = make(map[string]func(*url.URL, proxy_Dialer) (proxy_Dialer, error))
}
proxy_proxySchemes[scheme] = f
}
// FromURL returns a Dialer given a URL specification and an underlying
// Dialer for it to make network requests.
func proxy_FromURL(u *url.URL, forward proxy_Dialer) (proxy_Dialer, error) {
var auth *proxy_Auth
if u.User != nil {
auth = new(proxy_Auth)
auth.User = u.User.Username()
if p, ok := u.User.Password(); ok {
auth.Password = p
}
}
switch u.Scheme {
case "socks5":
return proxy_SOCKS5("tcp", u.Host, auth, forward)
}
// If the scheme doesn't match any of the built-in schemes, see if it
// was registered by another package.
if proxy_proxySchemes != nil {
if f, ok := proxy_proxySchemes[u.Scheme]; ok {
return f(u, forward)
}
}
return nil, errors.New("proxy: unknown scheme: " + u.Scheme)
}
var (
proxy_allProxyEnv = &proxy_envOnce{
names: []string{"ALL_PROXY", "all_proxy"},
}
proxy_noProxyEnv = &proxy_envOnce{
names: []string{"NO_PROXY", "no_proxy"},
}
)
// envOnce looks up an environment variable (optionally by multiple
// names) once. It mitigates expensive lookups on some platforms
// (e.g. Windows).
// (Borrowed from net/http/transport.go)
type proxy_envOnce struct {
names []string
once sync.Once
val string
}
func (e *proxy_envOnce) Get() string {
e.once.Do(e.init)
return e.val
}
func (e *proxy_envOnce) init() {
for _, n := range e.names {
e.val = os.Getenv(n)
if e.val != "" {
return
}
}
}
// SOCKS5 returns a Dialer that makes SOCKSv5 connections to the given address
// with an optional username and password. See RFC 1928 and RFC 1929.
func proxy_SOCKS5(network, addr string, auth *proxy_Auth, forward proxy_Dialer) (proxy_Dialer, error) {
s := &proxy_socks5{
network: network,
addr: addr,
forward: forward,
}
if auth != nil {
s.user = auth.User
s.password = auth.Password
}
return s, nil
}
type proxy_socks5 struct {
user, password string
network, addr string
forward proxy_Dialer
}
const proxy_socks5Version = 5
const (
proxy_socks5AuthNone = 0
proxy_socks5AuthPassword = 2
)
const proxy_socks5Connect = 1
const (
proxy_socks5IP4 = 1
proxy_socks5Domain = 3
proxy_socks5IP6 = 4
)
var proxy_socks5Errors = []string{
"",
"general failure",
"connection forbidden",
"network unreachable",
"host unreachable",
"connection refused",
"TTL expired",
"command not supported",
"address type not supported",
}
// Dial connects to the address addr on the given network via the SOCKS5 proxy.
func (s *proxy_socks5) Dial(network, addr string) (net.Conn, error) {
switch network {
case "tcp", "tcp6", "tcp4":
default:
return nil, errors.New("proxy: no support for SOCKS5 proxy connections of type " + network)
}
conn, err := s.forward.Dial(s.network, s.addr)
if err != nil {
return nil, err
}
if err := s.connect(conn, addr); err != nil {
conn.Close()
return nil, err
}
return conn, nil
}
// connect takes an existing connection to a socks5 proxy server,
// and commands the server to extend that connection to target,
// which must be a canonical address with a host and port.
func (s *proxy_socks5) connect(conn net.Conn, target string) error {
host, portStr, err := net.SplitHostPort(target)
if err != nil {
return err
}
port, err := strconv.Atoi(portStr)
if err != nil {
return errors.New("proxy: failed to parse port number: " + portStr)
}
if port < 1 || port > 0xffff {
return errors.New("proxy: port number out of range: " + portStr)
}
// the size here is just an estimate
buf := make([]byte, 0, 6+len(host))
buf = append(buf, proxy_socks5Version)
if len(s.user) > 0 && len(s.user) < 256 && len(s.password) < 256 {
buf = append(buf, 2 /* num auth methods */, proxy_socks5AuthNone, proxy_socks5AuthPassword)
} else {
buf = append(buf, 1 /* num auth methods */, proxy_socks5AuthNone)
}
if _, err := conn.Write(buf); err != nil {
return errors.New("proxy: failed to write greeting to SOCKS5 proxy at " + s.addr + ": " + err.Error())
}
if _, err := io.ReadFull(conn, buf[:2]); err != nil {
return errors.New("proxy: failed to read greeting from SOCKS5 proxy at " + s.addr + ": " + err.Error())
}
if buf[0] != 5 {
return errors.New("proxy: SOCKS5 proxy at " + s.addr + " has unexpected version " + strconv.Itoa(int(buf[0])))
}
if buf[1] == 0xff {
return errors.New("proxy: SOCKS5 proxy at " + s.addr + " requires authentication")
}
// See RFC 1929
if buf[1] == proxy_socks5AuthPassword {
buf = buf[:0]
buf = append(buf, 1 /* password protocol version */)
buf = append(buf, uint8(len(s.user)))
buf = append(buf, s.user...)
buf = append(buf, uint8(len(s.password)))
buf = append(buf, s.password...)
if _, err := conn.Write(buf); err != nil {
return errors.New("proxy: failed to write authentication request to SOCKS5 proxy at " + s.addr + ": " + err.Error())
}
if _, err := io.ReadFull(conn, buf[:2]); err != nil {
return errors.New("proxy: failed to read authentication reply from SOCKS5 proxy at " + s.addr + ": " + err.Error())
}
if buf[1] != 0 {
return errors.New("proxy: SOCKS5 proxy at " + s.addr + " rejected username/password")
}
}
buf = buf[:0]
buf = append(buf, proxy_socks5Version, proxy_socks5Connect, 0 /* reserved */)
if ip := net.ParseIP(host); ip != nil {
if ip4 := ip.To4(); ip4 != nil {
buf = append(buf, proxy_socks5IP4)
ip = ip4
} else {
buf = append(buf, proxy_socks5IP6)
}
buf = append(buf, ip...)
} else {
if len(host) > 255 {
return errors.New("proxy: destination host name too long: " + host)
}
buf = append(buf, proxy_socks5Domain)
buf = append(buf, byte(len(host)))
buf = append(buf, host...)
}
buf = append(buf, byte(port>>8), byte(port))
if _, err := conn.Write(buf); err != nil {
return errors.New("proxy: failed to write connect request to SOCKS5 proxy at " + s.addr + ": " + err.Error())
}
if _, err := io.ReadFull(conn, buf[:4]); err != nil {
return errors.New("proxy: failed to read connect reply from SOCKS5 proxy at " + s.addr + ": " + err.Error())
}
failure := "unknown error"
if int(buf[1]) < len(proxy_socks5Errors) {
failure = proxy_socks5Errors[buf[1]]
}
if len(failure) > 0 {
return errors.New("proxy: SOCKS5 proxy at " + s.addr + " failed to connect: " + failure)
}
bytesToDiscard := 0
switch buf[3] {
case proxy_socks5IP4:
bytesToDiscard = net.IPv4len
case proxy_socks5IP6:
bytesToDiscard = net.IPv6len
case proxy_socks5Domain:
_, err := io.ReadFull(conn, buf[:1])
if err != nil {
return errors.New("proxy: failed to read domain length from SOCKS5 proxy at " + s.addr + ": " + err.Error())
}
bytesToDiscard = int(buf[0])
default:
return errors.New("proxy: got unknown address type " + strconv.Itoa(int(buf[3])) + " from SOCKS5 proxy at " + s.addr)
}
if cap(buf) < bytesToDiscard {
buf = make([]byte, bytesToDiscard)
} else {
buf = buf[:bytesToDiscard]
}
if _, err := io.ReadFull(conn, buf); err != nil {
return errors.New("proxy: failed to read address from SOCKS5 proxy at " + s.addr + ": " + err.Error())
}
// Also need to discard the port number
if _, err := io.ReadFull(conn, buf[:2]); err != nil {
return errors.New("proxy: failed to read port from SOCKS5 proxy at " + s.addr + ": " + err.Error())
}
return nil
}

View File

@ -1,47 +0,0 @@
package discordgo
import (
"fmt"
"strings"
)
// A User stores all data for an individual Discord user.
type User struct {
ID string `json:"id"`
Email string `json:"email"`
Username string `json:"username"`
Avatar string `json:"avatar"`
Discriminator string `json:"discriminator"`
Token string `json:"token"`
Verified bool `json:"verified"`
MFAEnabled bool `json:"mfa_enabled"`
Bot bool `json:"bot"`
}
// String returns a unique identifier of the form username#discriminator
func (u *User) String() string {
return fmt.Sprintf("%s#%s", u.Username, u.Discriminator)
}
// Mention return a string which mentions the user
func (u *User) Mention() string {
return fmt.Sprintf("<@%s>", u.ID)
}
// AvatarURL returns a URL to the user's avatar.
// size: The size of the user's avatar as a power of two
// if size is an empty string, no size parameter will
// be added to the URL.
func (u *User) AvatarURL(size string) string {
var URL string
if strings.HasPrefix(u.Avatar, "a_") {
URL = EndpointUserAvatarAnimated(u.ID, u.Avatar)
} else {
URL = EndpointUserAvatar(u.ID, u.Avatar)
}
if size != "" {
return URL + "?size=" + size
}
return URL
}

View File

@ -14,7 +14,6 @@
package acme package acme
import ( import (
"bytes"
"context" "context"
"crypto" "crypto"
"crypto/ecdsa" "crypto/ecdsa"
@ -23,6 +22,8 @@ import (
"crypto/sha256" "crypto/sha256"
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"crypto/x509/pkix"
"encoding/asn1"
"encoding/base64" "encoding/base64"
"encoding/hex" "encoding/hex"
"encoding/json" "encoding/json"
@ -33,14 +34,26 @@ import (
"io/ioutil" "io/ioutil"
"math/big" "math/big"
"net/http" "net/http"
"strconv"
"strings" "strings"
"sync" "sync"
"time" "time"
) )
const (
// LetsEncryptURL is the Directory endpoint of Let's Encrypt CA. // LetsEncryptURL is the Directory endpoint of Let's Encrypt CA.
const LetsEncryptURL = "https://acme-v01.api.letsencrypt.org/directory" LetsEncryptURL = "https://acme-v01.api.letsencrypt.org/directory"
// ALPNProto is the ALPN protocol name used by a CA server when validating
// tls-alpn-01 challenges.
//
// Package users must ensure their servers can negotiate the ACME ALPN in
// order for tls-alpn-01 challenge verifications to succeed.
// See the crypto/tls package's Config.NextProtos field.
ALPNProto = "acme-tls/1"
)
// idPeACMEIdentifierV1 is the OID for the ACME extension for the TLS-ALPN challenge.
var idPeACMEIdentifierV1 = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 1, 30, 1}
const ( const (
maxChainLen = 5 // max depth and breadth of a certificate chain maxChainLen = 5 // max depth and breadth of a certificate chain
@ -76,6 +89,22 @@ type Client struct {
// will have no effect. // will have no effect.
DirectoryURL string DirectoryURL string
// RetryBackoff computes the duration after which the nth retry of a failed request
// should occur. The value of n for the first call on failure is 1.
// The values of r and resp are the request and response of the last failed attempt.
// If the returned value is negative or zero, no more retries are done and an error
// is returned to the caller of the original method.
//
// Requests which result in a 4xx client error are not retried,
// except for 400 Bad Request due to "bad nonce" errors and 429 Too Many Requests.
//
// If RetryBackoff is nil, a truncated exponential backoff algorithm
// with the ceiling of 10 seconds is used, where each subsequent retry n
// is done after either ("Retry-After" + jitter) or (2^n seconds + jitter),
// preferring the former if "Retry-After" header is found in the resp.
// The jitter is a random value up to 1 second.
RetryBackoff func(n int, r *http.Request, resp *http.Response) time.Duration
dirMu sync.Mutex // guards writes to dir dirMu sync.Mutex // guards writes to dir
dir *Directory // cached result of Client's Discover method dir *Directory // cached result of Client's Discover method
@ -99,15 +128,12 @@ func (c *Client) Discover(ctx context.Context) (Directory, error) {
if dirURL == "" { if dirURL == "" {
dirURL = LetsEncryptURL dirURL = LetsEncryptURL
} }
res, err := c.get(ctx, dirURL) res, err := c.get(ctx, dirURL, wantStatus(http.StatusOK))
if err != nil { if err != nil {
return Directory{}, err return Directory{}, err
} }
defer res.Body.Close() defer res.Body.Close()
c.addNonce(res.Header) c.addNonce(res.Header)
if res.StatusCode != http.StatusOK {
return Directory{}, responseError(res)
}
var v struct { var v struct {
Reg string `json:"new-reg"` Reg string `json:"new-reg"`
@ -166,14 +192,11 @@ func (c *Client) CreateCert(ctx context.Context, csr []byte, exp time.Duration,
req.NotAfter = now.Add(exp).Format(time.RFC3339) req.NotAfter = now.Add(exp).Format(time.RFC3339)
} }
res, err := c.retryPostJWS(ctx, c.Key, c.dir.CertURL, req) res, err := c.post(ctx, c.Key, c.dir.CertURL, req, wantStatus(http.StatusCreated))
if err != nil { if err != nil {
return nil, "", err return nil, "", err
} }
defer res.Body.Close() defer res.Body.Close()
if res.StatusCode != http.StatusCreated {
return nil, "", responseError(res)
}
curl := res.Header.Get("Location") // cert permanent URL curl := res.Header.Get("Location") // cert permanent URL
if res.ContentLength == 0 { if res.ContentLength == 0 {
@ -196,27 +219,12 @@ func (c *Client) CreateCert(ctx context.Context, csr []byte, exp time.Duration,
// Callers are encouraged to parse the returned value to ensure the certificate is valid // Callers are encouraged to parse the returned value to ensure the certificate is valid
// and has expected features. // and has expected features.
func (c *Client) FetchCert(ctx context.Context, url string, bundle bool) ([][]byte, error) { func (c *Client) FetchCert(ctx context.Context, url string, bundle bool) ([][]byte, error) {
for { res, err := c.get(ctx, url, wantStatus(http.StatusOK))
res, err := c.get(ctx, url)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer res.Body.Close()
if res.StatusCode == http.StatusOK {
return c.responseCert(ctx, res, bundle) return c.responseCert(ctx, res, bundle)
} }
if res.StatusCode > 299 {
return nil, responseError(res)
}
d := retryAfter(res.Header.Get("Retry-After"), 3*time.Second)
select {
case <-time.After(d):
// retry
case <-ctx.Done():
return nil, ctx.Err()
}
}
}
// RevokeCert revokes a previously issued certificate cert, provided in DER format. // RevokeCert revokes a previously issued certificate cert, provided in DER format.
// //
@ -241,14 +249,11 @@ func (c *Client) RevokeCert(ctx context.Context, key crypto.Signer, cert []byte,
if key == nil { if key == nil {
key = c.Key key = c.Key
} }
res, err := c.retryPostJWS(ctx, key, c.dir.RevokeURL, body) res, err := c.post(ctx, key, c.dir.RevokeURL, body, wantStatus(http.StatusOK))
if err != nil { if err != nil {
return err return err
} }
defer res.Body.Close() defer res.Body.Close()
if res.StatusCode != http.StatusOK {
return responseError(res)
}
return nil return nil
} }
@ -329,14 +334,11 @@ func (c *Client) Authorize(ctx context.Context, domain string) (*Authorization,
Resource: "new-authz", Resource: "new-authz",
Identifier: authzID{Type: "dns", Value: domain}, Identifier: authzID{Type: "dns", Value: domain},
} }
res, err := c.retryPostJWS(ctx, c.Key, c.dir.AuthzURL, req) res, err := c.post(ctx, c.Key, c.dir.AuthzURL, req, wantStatus(http.StatusCreated))
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer res.Body.Close() defer res.Body.Close()
if res.StatusCode != http.StatusCreated {
return nil, responseError(res)
}
var v wireAuthz var v wireAuthz
if err := json.NewDecoder(res.Body).Decode(&v); err != nil { if err := json.NewDecoder(res.Body).Decode(&v); err != nil {
@ -353,14 +355,11 @@ func (c *Client) Authorize(ctx context.Context, domain string) (*Authorization,
// If a caller needs to poll an authorization until its status is final, // If a caller needs to poll an authorization until its status is final,
// see the WaitAuthorization method. // see the WaitAuthorization method.
func (c *Client) GetAuthorization(ctx context.Context, url string) (*Authorization, error) { func (c *Client) GetAuthorization(ctx context.Context, url string) (*Authorization, error) {
res, err := c.get(ctx, url) res, err := c.get(ctx, url, wantStatus(http.StatusOK, http.StatusAccepted))
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer res.Body.Close() defer res.Body.Close()
if res.StatusCode != http.StatusOK && res.StatusCode != http.StatusAccepted {
return nil, responseError(res)
}
var v wireAuthz var v wireAuthz
if err := json.NewDecoder(res.Body).Decode(&v); err != nil { if err := json.NewDecoder(res.Body).Decode(&v); err != nil {
return nil, fmt.Errorf("acme: invalid response: %v", err) return nil, fmt.Errorf("acme: invalid response: %v", err)
@ -387,14 +386,11 @@ func (c *Client) RevokeAuthorization(ctx context.Context, url string) error {
Status: "deactivated", Status: "deactivated",
Delete: true, Delete: true,
} }
res, err := c.retryPostJWS(ctx, c.Key, url, req) res, err := c.post(ctx, c.Key, url, req, wantStatus(http.StatusOK))
if err != nil { if err != nil {
return err return err
} }
defer res.Body.Close() defer res.Body.Close()
if res.StatusCode != http.StatusOK {
return responseError(res)
}
return nil return nil
} }
@ -406,44 +402,42 @@ func (c *Client) RevokeAuthorization(ctx context.Context, url string) error {
// In all other cases WaitAuthorization returns an error. // In all other cases WaitAuthorization returns an error.
// If the Status is StatusInvalid, the returned error is of type *AuthorizationError. // If the Status is StatusInvalid, the returned error is of type *AuthorizationError.
func (c *Client) WaitAuthorization(ctx context.Context, url string) (*Authorization, error) { func (c *Client) WaitAuthorization(ctx context.Context, url string) (*Authorization, error) {
sleep := sleeper(ctx)
for { for {
res, err := c.get(ctx, url) res, err := c.get(ctx, url, wantStatus(http.StatusOK, http.StatusAccepted))
if err != nil { if err != nil {
return nil, err return nil, err
} }
if res.StatusCode >= 400 && res.StatusCode <= 499 {
// Non-retriable error. For instance, Let's Encrypt may return 404 Not Found
// when requesting an expired authorization.
defer res.Body.Close()
return nil, responseError(res)
}
retry := res.Header.Get("Retry-After")
if res.StatusCode != http.StatusOK && res.StatusCode != http.StatusAccepted {
res.Body.Close()
if err := sleep(retry, 1); err != nil {
return nil, err
}
continue
}
var raw wireAuthz var raw wireAuthz
err = json.NewDecoder(res.Body).Decode(&raw) err = json.NewDecoder(res.Body).Decode(&raw)
res.Body.Close() res.Body.Close()
if err != nil { switch {
if err := sleep(retry, 0); err != nil { case err != nil:
return nil, err // Skip and retry.
} case raw.Status == StatusValid:
continue
}
if raw.Status == StatusValid {
return raw.authorization(url), nil return raw.authorization(url), nil
} case raw.Status == StatusInvalid:
if raw.Status == StatusInvalid {
return nil, raw.error(url) return nil, raw.error(url)
} }
if err := sleep(retry, 0); err != nil {
return nil, err // Exponential backoff is implemented in c.get above.
// This is just to prevent continuously hitting the CA
// while waiting for a final authorization status.
d := retryAfter(res.Header.Get("Retry-After"))
if d == 0 {
// Given that the fastest challenges TLS-SNI and HTTP-01
// require a CA to make at least 1 network round trip
// and most likely persist a challenge state,
// this default delay seems reasonable.
d = time.Second
}
t := time.NewTimer(d)
select {
case <-ctx.Done():
t.Stop()
return nil, ctx.Err()
case <-t.C:
// Retry.
} }
} }
} }
@ -452,14 +446,11 @@ func (c *Client) WaitAuthorization(ctx context.Context, url string) (*Authorizat
// //
// A client typically polls a challenge status using this method. // A client typically polls a challenge status using this method.
func (c *Client) GetChallenge(ctx context.Context, url string) (*Challenge, error) { func (c *Client) GetChallenge(ctx context.Context, url string) (*Challenge, error) {
res, err := c.get(ctx, url) res, err := c.get(ctx, url, wantStatus(http.StatusOK, http.StatusAccepted))
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer res.Body.Close() defer res.Body.Close()
if res.StatusCode != http.StatusOK && res.StatusCode != http.StatusAccepted {
return nil, responseError(res)
}
v := wireChallenge{URI: url} v := wireChallenge{URI: url}
if err := json.NewDecoder(res.Body).Decode(&v); err != nil { if err := json.NewDecoder(res.Body).Decode(&v); err != nil {
return nil, fmt.Errorf("acme: invalid response: %v", err) return nil, fmt.Errorf("acme: invalid response: %v", err)
@ -486,16 +477,14 @@ func (c *Client) Accept(ctx context.Context, chal *Challenge) (*Challenge, error
Type: chal.Type, Type: chal.Type,
Auth: auth, Auth: auth,
} }
res, err := c.retryPostJWS(ctx, c.Key, chal.URI, req) res, err := c.post(ctx, c.Key, chal.URI, req, wantStatus(
http.StatusOK, // according to the spec
http.StatusAccepted, // Let's Encrypt: see https://goo.gl/WsJ7VT (acme-divergences.md)
))
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer res.Body.Close() defer res.Body.Close()
// Note: the protocol specifies 200 as the expected response code, but
// letsencrypt seems to be returning 202.
if res.StatusCode != http.StatusOK && res.StatusCode != http.StatusAccepted {
return nil, responseError(res)
}
var v wireChallenge var v wireChallenge
if err := json.NewDecoder(res.Body).Decode(&v); err != nil { if err := json.NewDecoder(res.Body).Decode(&v); err != nil {
@ -552,7 +541,7 @@ func (c *Client) HTTP01ChallengePath(token string) string {
// If no WithKey option is provided, a new ECDSA key is generated using P-256 curve. // If no WithKey option is provided, a new ECDSA key is generated using P-256 curve.
// //
// The returned certificate is valid for the next 24 hours and must be presented only when // The returned certificate is valid for the next 24 hours and must be presented only when
// the server name of the client hello matches exactly the returned name value. // the server name of the TLS ClientHello matches exactly the returned name value.
func (c *Client) TLSSNI01ChallengeCert(token string, opt ...CertOption) (cert tls.Certificate, name string, err error) { func (c *Client) TLSSNI01ChallengeCert(token string, opt ...CertOption) (cert tls.Certificate, name string, err error) {
ka, err := keyAuth(c.Key.Public(), token) ka, err := keyAuth(c.Key.Public(), token)
if err != nil { if err != nil {
@ -579,7 +568,7 @@ func (c *Client) TLSSNI01ChallengeCert(token string, opt ...CertOption) (cert tl
// If no WithKey option is provided, a new ECDSA key is generated using P-256 curve. // If no WithKey option is provided, a new ECDSA key is generated using P-256 curve.
// //
// The returned certificate is valid for the next 24 hours and must be presented only when // The returned certificate is valid for the next 24 hours and must be presented only when
// the server name in the client hello matches exactly the returned name value. // the server name in the TLS ClientHello matches exactly the returned name value.
func (c *Client) TLSSNI02ChallengeCert(token string, opt ...CertOption) (cert tls.Certificate, name string, err error) { func (c *Client) TLSSNI02ChallengeCert(token string, opt ...CertOption) (cert tls.Certificate, name string, err error) {
b := sha256.Sum256([]byte(token)) b := sha256.Sum256([]byte(token))
h := hex.EncodeToString(b[:]) h := hex.EncodeToString(b[:])
@ -600,6 +589,52 @@ func (c *Client) TLSSNI02ChallengeCert(token string, opt ...CertOption) (cert tl
return cert, sanA, nil return cert, sanA, nil
} }
// TLSALPN01ChallengeCert creates a certificate for TLS-ALPN-01 challenge response.
// Servers can present the certificate to validate the challenge and prove control
// over a domain name. For more details on TLS-ALPN-01 see
// https://tools.ietf.org/html/draft-shoemaker-acme-tls-alpn-00#section-3
//
// The token argument is a Challenge.Token value.
// If a WithKey option is provided, its private part signs the returned cert,
// and the public part is used to specify the signee.
// If no WithKey option is provided, a new ECDSA key is generated using P-256 curve.
//
// The returned certificate is valid for the next 24 hours and must be presented only when
// the server name in the TLS ClientHello matches the domain, and the special acme-tls/1 ALPN protocol
// has been specified.
func (c *Client) TLSALPN01ChallengeCert(token, domain string, opt ...CertOption) (cert tls.Certificate, err error) {
ka, err := keyAuth(c.Key.Public(), token)
if err != nil {
return tls.Certificate{}, err
}
shasum := sha256.Sum256([]byte(ka))
extValue, err := asn1.Marshal(shasum[:])
if err != nil {
return tls.Certificate{}, err
}
acmeExtension := pkix.Extension{
Id: idPeACMEIdentifierV1,
Critical: true,
Value: extValue,
}
tmpl := defaultTLSChallengeCertTemplate()
var newOpt []CertOption
for _, o := range opt {
switch o := o.(type) {
case *certOptTemplate:
t := *(*x509.Certificate)(o) // shallow copy is ok
tmpl = &t
default:
newOpt = append(newOpt, o)
}
}
tmpl.ExtraExtensions = append(tmpl.ExtraExtensions, acmeExtension)
newOpt = append(newOpt, WithTemplate(tmpl))
return tlsChallengeCert([]string{domain}, newOpt)
}
// doReg sends all types of registration requests. // doReg sends all types of registration requests.
// The type of request is identified by typ argument, which is a "resource" // The type of request is identified by typ argument, which is a "resource"
// in the ACME spec terms. // in the ACME spec terms.
@ -619,14 +654,15 @@ func (c *Client) doReg(ctx context.Context, url string, typ string, acct *Accoun
req.Contact = acct.Contact req.Contact = acct.Contact
req.Agreement = acct.AgreedTerms req.Agreement = acct.AgreedTerms
} }
res, err := c.retryPostJWS(ctx, c.Key, url, req) res, err := c.post(ctx, c.Key, url, req, wantStatus(
http.StatusOK, // updates and deletes
http.StatusCreated, // new account creation
http.StatusAccepted, // Let's Encrypt divergent implementation
))
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer res.Body.Close() defer res.Body.Close()
if res.StatusCode < 200 || res.StatusCode > 299 {
return nil, responseError(res)
}
var v struct { var v struct {
Contact []string Contact []string
@ -656,59 +692,6 @@ func (c *Client) doReg(ctx context.Context, url string, typ string, acct *Accoun
}, nil }, nil
} }
// retryPostJWS will retry calls to postJWS if there is a badNonce error,
// clearing the stored nonces after each error.
// If the response was 4XX-5XX, then responseError is called on the body,
// the body is closed, and the error returned.
func (c *Client) retryPostJWS(ctx context.Context, key crypto.Signer, url string, body interface{}) (*http.Response, error) {
sleep := sleeper(ctx)
for {
res, err := c.postJWS(ctx, key, url, body)
if err != nil {
return nil, err
}
// handle errors 4XX-5XX with responseError
if res.StatusCode >= 400 && res.StatusCode <= 599 {
err := responseError(res)
res.Body.Close()
// according to spec badNonce is urn:ietf:params:acme:error:badNonce
// however, acme servers in the wild return their version of the error
// https://tools.ietf.org/html/draft-ietf-acme-acme-02#section-5.4
if ae, ok := err.(*Error); ok && strings.HasSuffix(strings.ToLower(ae.ProblemType), ":badnonce") {
// clear any nonces that we might've stored that might now be
// considered bad
c.clearNonces()
retry := res.Header.Get("Retry-After")
if err := sleep(retry, 1); err != nil {
return nil, err
}
continue
}
return nil, err
}
return res, nil
}
}
// postJWS signs the body with the given key and POSTs it to the provided url.
// The body argument must be JSON-serializable.
func (c *Client) postJWS(ctx context.Context, key crypto.Signer, url string, body interface{}) (*http.Response, error) {
nonce, err := c.popNonce(ctx, url)
if err != nil {
return nil, err
}
b, err := jwsEncodeJSON(body, key, nonce)
if err != nil {
return nil, err
}
res, err := c.post(ctx, url, "application/jose+json", bytes.NewReader(b))
if err != nil {
return nil, err
}
c.addNonce(res.Header)
return res, nil
}
// popNonce returns a nonce value previously stored with c.addNonce // popNonce returns a nonce value previously stored with c.addNonce
// or fetches a fresh one from the given URL. // or fetches a fresh one from the given URL.
func (c *Client) popNonce(ctx context.Context, url string) (string, error) { func (c *Client) popNonce(ctx context.Context, url string) (string, error) {
@ -749,58 +732,12 @@ func (c *Client) addNonce(h http.Header) {
c.nonces[v] = struct{}{} c.nonces[v] = struct{}{}
} }
func (c *Client) httpClient() *http.Client {
if c.HTTPClient != nil {
return c.HTTPClient
}
return http.DefaultClient
}
func (c *Client) get(ctx context.Context, urlStr string) (*http.Response, error) {
req, err := http.NewRequest("GET", urlStr, nil)
if err != nil {
return nil, err
}
return c.do(ctx, req)
}
func (c *Client) head(ctx context.Context, urlStr string) (*http.Response, error) {
req, err := http.NewRequest("HEAD", urlStr, nil)
if err != nil {
return nil, err
}
return c.do(ctx, req)
}
func (c *Client) post(ctx context.Context, urlStr, contentType string, body io.Reader) (*http.Response, error) {
req, err := http.NewRequest("POST", urlStr, body)
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", contentType)
return c.do(ctx, req)
}
func (c *Client) do(ctx context.Context, req *http.Request) (*http.Response, error) {
res, err := c.httpClient().Do(req.WithContext(ctx))
if err != nil {
select {
case <-ctx.Done():
// Prefer the unadorned context error.
// (The acme package had tests assuming this, previously from ctxhttp's
// behavior, predating net/http supporting contexts natively)
// TODO(bradfitz): reconsider this in the future. But for now this
// requires no test updates.
return nil, ctx.Err()
default:
return nil, err
}
}
return res, nil
}
func (c *Client) fetchNonce(ctx context.Context, url string) (string, error) { func (c *Client) fetchNonce(ctx context.Context, url string) (string, error) {
resp, err := c.head(ctx, url) r, err := http.NewRequest("HEAD", url, nil)
if err != nil {
return "", err
}
resp, err := c.doNoRetry(ctx, r)
if err != nil { if err != nil {
return "", err return "", err
} }
@ -852,24 +789,6 @@ func (c *Client) responseCert(ctx context.Context, res *http.Response, bundle bo
return cert, nil return cert, nil
} }
// responseError creates an error of Error type from resp.
func responseError(resp *http.Response) error {
// don't care if ReadAll returns an error:
// json.Unmarshal will fail in that case anyway
b, _ := ioutil.ReadAll(resp.Body)
e := &wireError{Status: resp.StatusCode}
if err := json.Unmarshal(b, e); err != nil {
// this is not a regular error response:
// populate detail with anything we received,
// e.Status will already contain HTTP response code value
e.Detail = string(b)
if e.Detail == "" {
e.Detail = resp.Status
}
}
return e.error(resp.Header)
}
// chainCert fetches CA certificate chain recursively by following "up" links. // chainCert fetches CA certificate chain recursively by following "up" links.
// Each recursive call increments the depth by 1, resulting in an error // Each recursive call increments the depth by 1, resulting in an error
// if the recursion level reaches maxChainLen. // if the recursion level reaches maxChainLen.
@ -880,14 +799,11 @@ func (c *Client) chainCert(ctx context.Context, url string, depth int) ([][]byte
return nil, errors.New("acme: certificate chain is too deep") return nil, errors.New("acme: certificate chain is too deep")
} }
res, err := c.get(ctx, url) res, err := c.get(ctx, url, wantStatus(http.StatusOK))
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer res.Body.Close() defer res.Body.Close()
if res.StatusCode != http.StatusOK {
return nil, responseError(res)
}
b, err := ioutil.ReadAll(io.LimitReader(res.Body, maxCertSize+1)) b, err := ioutil.ReadAll(io.LimitReader(res.Body, maxCertSize+1))
if err != nil { if err != nil {
return nil, err return nil, err
@ -932,65 +848,6 @@ func linkHeader(h http.Header, rel string) []string {
return links return links
} }
// sleeper returns a function that accepts the Retry-After HTTP header value
// and an increment that's used with backoff to increasingly sleep on
// consecutive calls until the context is done. If the Retry-After header
// cannot be parsed, then backoff is used with a maximum sleep time of 10
// seconds.
func sleeper(ctx context.Context) func(ra string, inc int) error {
var count int
return func(ra string, inc int) error {
count += inc
d := backoff(count, 10*time.Second)
d = retryAfter(ra, d)
wakeup := time.NewTimer(d)
defer wakeup.Stop()
select {
case <-ctx.Done():
return ctx.Err()
case <-wakeup.C:
return nil
}
}
}
// retryAfter parses a Retry-After HTTP header value,
// trying to convert v into an int (seconds) or use http.ParseTime otherwise.
// It returns d if v cannot be parsed.
func retryAfter(v string, d time.Duration) time.Duration {
if i, err := strconv.Atoi(v); err == nil {
return time.Duration(i) * time.Second
}
t, err := http.ParseTime(v)
if err != nil {
return d
}
return t.Sub(timeNow())
}
// backoff computes a duration after which an n+1 retry iteration should occur
// using truncated exponential backoff algorithm.
//
// The n argument is always bounded between 0 and 30.
// The max argument defines upper bound for the returned value.
func backoff(n int, max time.Duration) time.Duration {
if n < 0 {
n = 0
}
if n > 30 {
n = 30
}
var d time.Duration
if x, err := rand.Int(rand.Reader, big.NewInt(1000)); err == nil {
d = time.Duration(x.Int64()) * time.Millisecond
}
d += time.Duration(1<<uint(n)) * time.Second
if d > max {
return max
}
return d
}
// keyAuth generates a key authorization string for a given token. // keyAuth generates a key authorization string for a given token.
func keyAuth(pub crypto.PublicKey, token string) (string, error) { func keyAuth(pub crypto.PublicKey, token string) (string, error) {
th, err := JWKThumbprint(pub) th, err := JWKThumbprint(pub)
@ -1000,15 +857,25 @@ func keyAuth(pub crypto.PublicKey, token string) (string, error) {
return fmt.Sprintf("%s.%s", token, th), nil return fmt.Sprintf("%s.%s", token, th), nil
} }
// defaultTLSChallengeCertTemplate is a template used to create challenge certs for TLS challenges.
func defaultTLSChallengeCertTemplate() *x509.Certificate {
return &x509.Certificate{
SerialNumber: big.NewInt(1),
NotBefore: time.Now(),
NotAfter: time.Now().Add(24 * time.Hour),
BasicConstraintsValid: true,
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
}
}
// tlsChallengeCert creates a temporary certificate for TLS-SNI challenges // tlsChallengeCert creates a temporary certificate for TLS-SNI challenges
// with the given SANs and auto-generated public/private key pair. // with the given SANs and auto-generated public/private key pair.
// The Subject Common Name is set to the first SAN to aid debugging. // The Subject Common Name is set to the first SAN to aid debugging.
// To create a cert with a custom key pair, specify WithKey option. // To create a cert with a custom key pair, specify WithKey option.
func tlsChallengeCert(san []string, opt []CertOption) (tls.Certificate, error) { func tlsChallengeCert(san []string, opt []CertOption) (tls.Certificate, error) {
var ( var key crypto.Signer
key crypto.Signer tmpl := defaultTLSChallengeCertTemplate()
tmpl *x509.Certificate
)
for _, o := range opt { for _, o := range opt {
switch o := o.(type) { switch o := o.(type) {
case *certOptKey: case *certOptKey:
@ -1017,7 +884,7 @@ func tlsChallengeCert(san []string, opt []CertOption) (tls.Certificate, error) {
} }
key = o.key key = o.key
case *certOptTemplate: case *certOptTemplate:
var t = *(*x509.Certificate)(o) // shallow copy is ok t := *(*x509.Certificate)(o) // shallow copy is ok
tmpl = &t tmpl = &t
default: default:
// package's fault, if we let this happen: // package's fault, if we let this happen:
@ -1030,16 +897,6 @@ func tlsChallengeCert(san []string, opt []CertOption) (tls.Certificate, error) {
return tls.Certificate{}, err return tls.Certificate{}, err
} }
} }
if tmpl == nil {
tmpl = &x509.Certificate{
SerialNumber: big.NewInt(1),
NotBefore: time.Now(),
NotAfter: time.Now().Add(24 * time.Hour),
BasicConstraintsValid: true,
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
}
}
tmpl.DNSNames = san tmpl.DNSNames = san
if len(san) > 0 { if len(san) > 0 {
tmpl.Subject.CommonName = san[0] tmpl.Subject.CommonName = san[0]

View File

@ -44,7 +44,7 @@ var createCertRetryAfter = time.Minute
var pseudoRand *lockedMathRand var pseudoRand *lockedMathRand
func init() { func init() {
src := mathrand.NewSource(timeNow().UnixNano()) src := mathrand.NewSource(time.Now().UnixNano())
pseudoRand = &lockedMathRand{rnd: mathrand.New(src)} pseudoRand = &lockedMathRand{rnd: mathrand.New(src)}
} }
@ -69,7 +69,7 @@ func HostWhitelist(hosts ...string) HostPolicy {
} }
return func(_ context.Context, host string) error { return func(_ context.Context, host string) error {
if !whitelist[host] { if !whitelist[host] {
return errors.New("acme/autocert: host not configured") return fmt.Errorf("acme/autocert: host %q not configured in HostWhitelist", host)
} }
return nil return nil
} }
@ -81,9 +81,9 @@ func defaultHostPolicy(context.Context, string) error {
} }
// Manager is a stateful certificate manager built on top of acme.Client. // Manager is a stateful certificate manager built on top of acme.Client.
// It obtains and refreshes certificates automatically using "tls-sni-01", // It obtains and refreshes certificates automatically using "tls-alpn-01",
// "tls-sni-02" and "http-01" challenge types, as well as providing them // "tls-sni-01", "tls-sni-02" and "http-01" challenge types,
// to a TLS server via tls.Config. // as well as providing them to a TLS server via tls.Config.
// //
// You must specify a cache implementation, such as DirCache, // You must specify a cache implementation, such as DirCache,
// to reuse obtained certificates across program restarts. // to reuse obtained certificates across program restarts.
@ -98,11 +98,11 @@ type Manager struct {
// To always accept the terms, the callers can use AcceptTOS. // To always accept the terms, the callers can use AcceptTOS.
Prompt func(tosURL string) bool Prompt func(tosURL string) bool
// Cache optionally stores and retrieves previously-obtained certificates. // Cache optionally stores and retrieves previously-obtained certificates
// If nil, certs will only be cached for the lifetime of the Manager. // and other state. If nil, certs will only be cached for the lifetime of
// the Manager. Multiple Managers can share the same Cache.
// //
// Manager passes the Cache certificates data encoded in PEM, with private/public // Using a persistent Cache, such as DirCache, is strongly recommended.
// parts combined in a single Cache.Put call, private key first.
Cache Cache Cache Cache
// HostPolicy controls which domains the Manager will attempt // HostPolicy controls which domains the Manager will attempt
@ -127,8 +127,10 @@ type Manager struct {
// Client is used to perform low-level operations, such as account registration // Client is used to perform low-level operations, such as account registration
// and requesting new certificates. // and requesting new certificates.
//
// If Client is nil, a zero-value acme.Client is used with acme.LetsEncryptURL // If Client is nil, a zero-value acme.Client is used with acme.LetsEncryptURL
// directory endpoint and a newly-generated ECDSA P-256 key. // as directory endpoint. If the Client.Key is nil, a new ECDSA P-256 key is
// generated and, if Cache is not nil, stored in cache.
// //
// Mutating the field after the first call of GetCertificate method will have no effect. // Mutating the field after the first call of GetCertificate method will have no effect.
Client *acme.Client Client *acme.Client
@ -140,22 +142,30 @@ type Manager struct {
// If the Client's account key is already registered, Email is not used. // If the Client's account key is already registered, Email is not used.
Email string Email string
// ForceRSA makes the Manager generate certificates with 2048-bit RSA keys. // ForceRSA used to make the Manager generate RSA certificates. It is now ignored.
// //
// If false, a default is used. Currently the default // Deprecated: the Manager will request the correct type of certificate based
// is EC-based keys using the P-256 curve. // on what each client supports.
ForceRSA bool ForceRSA bool
// ExtraExtensions are used when generating a new CSR (Certificate Request),
// thus allowing customization of the resulting certificate.
// For instance, TLS Feature Extension (RFC 7633) can be used
// to prevent an OCSP downgrade attack.
//
// The field value is passed to crypto/x509.CreateCertificateRequest
// in the template's ExtraExtensions field as is.
ExtraExtensions []pkix.Extension
clientMu sync.Mutex clientMu sync.Mutex
client *acme.Client // initialized by acmeClient method client *acme.Client // initialized by acmeClient method
stateMu sync.Mutex stateMu sync.Mutex
state map[string]*certState // keyed by domain name state map[certKey]*certState
// renewal tracks the set of domains currently running renewal timers. // renewal tracks the set of domains currently running renewal timers.
// It is keyed by domain name.
renewalMu sync.Mutex renewalMu sync.Mutex
renewal map[string]*domainRenewal renewal map[certKey]*domainRenewal
// tokensMu guards the rest of the fields: tryHTTP01, certTokens and httpTokens. // tokensMu guards the rest of the fields: tryHTTP01, certTokens and httpTokens.
tokensMu sync.RWMutex tokensMu sync.RWMutex
@ -167,21 +177,60 @@ type Manager struct {
// to be provisioned. // to be provisioned.
// The entries are stored for the duration of the authorization flow. // The entries are stored for the duration of the authorization flow.
httpTokens map[string][]byte httpTokens map[string][]byte
// certTokens contains temporary certificates for tls-sni challenges // certTokens contains temporary certificates for tls-sni and tls-alpn challenges
// and is keyed by token domain name, which matches server name of ClientHello. // and is keyed by token domain name, which matches server name of ClientHello.
// Keys always have ".acme.invalid" suffix. // Keys always have ".acme.invalid" suffix for tls-sni. Otherwise, they are domain names
// for tls-alpn.
// The entries are stored for the duration of the authorization flow. // The entries are stored for the duration of the authorization flow.
certTokens map[string]*tls.Certificate certTokens map[string]*tls.Certificate
// nowFunc, if not nil, returns the current time. This may be set for
// testing purposes.
nowFunc func() time.Time
}
// certKey is the key by which certificates are tracked in state, renewal and cache.
type certKey struct {
domain string // without trailing dot
isRSA bool // RSA cert for legacy clients (as opposed to default ECDSA)
isToken bool // tls-based challenge token cert; key type is undefined regardless of isRSA
}
func (c certKey) String() string {
if c.isToken {
return c.domain + "+token"
}
if c.isRSA {
return c.domain + "+rsa"
}
return c.domain
}
// TLSConfig creates a new TLS config suitable for net/http.Server servers,
// supporting HTTP/2 and the tls-alpn-01 ACME challenge type.
func (m *Manager) TLSConfig() *tls.Config {
return &tls.Config{
GetCertificate: m.GetCertificate,
NextProtos: []string{
"h2", "http/1.1", // enable HTTP/2
acme.ALPNProto, // enable tls-alpn ACME challenges
},
}
} }
// GetCertificate implements the tls.Config.GetCertificate hook. // GetCertificate implements the tls.Config.GetCertificate hook.
// It provides a TLS certificate for hello.ServerName host, including answering // It provides a TLS certificate for hello.ServerName host, including answering
// *.acme.invalid (TLS-SNI) challenges. All other fields of hello are ignored. // tls-alpn-01 and *.acme.invalid (tls-sni-01 and tls-sni-02) challenges.
// All other fields of hello are ignored.
// //
// If m.HostPolicy is non-nil, GetCertificate calls the policy before requesting // If m.HostPolicy is non-nil, GetCertificate calls the policy before requesting
// a new cert. A non-nil error returned from m.HostPolicy halts TLS negotiation. // a new cert. A non-nil error returned from m.HostPolicy halts TLS negotiation.
// The error is propagated back to the caller of GetCertificate and is user-visible. // The error is propagated back to the caller of GetCertificate and is user-visible.
// This does not affect cached certs. See HostPolicy field description for more details. // This does not affect cached certs. See HostPolicy field description for more details.
//
// If GetCertificate is used directly, instead of via Manager.TLSConfig, package users will
// also have to add acme.ALPNProto to NextProtos for tls-alpn-01, or use HTTPHandler
// for http-01. (The tls-sni-* challenges have been deprecated by popular ACME providers
// due to security issues in the ecosystem.)
func (m *Manager) GetCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate, error) { func (m *Manager) GetCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
if m.Prompt == nil { if m.Prompt == nil {
return nil, errors.New("acme/autocert: Manager.Prompt not set") return nil, errors.New("acme/autocert: Manager.Prompt not set")
@ -194,7 +243,7 @@ func (m *Manager) GetCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate,
if !strings.Contains(strings.Trim(name, "."), ".") { if !strings.Contains(strings.Trim(name, "."), ".") {
return nil, errors.New("acme/autocert: server name component count invalid") return nil, errors.New("acme/autocert: server name component count invalid")
} }
if strings.ContainsAny(name, `/\`) { if strings.ContainsAny(name, `+/\`) {
return nil, errors.New("acme/autocert: server name contains invalid character") return nil, errors.New("acme/autocert: server name contains invalid character")
} }
@ -203,14 +252,17 @@ func (m *Manager) GetCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate,
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
defer cancel() defer cancel()
// check whether this is a token cert requested for TLS-SNI challenge // Check whether this is a token cert requested for TLS-SNI or TLS-ALPN challenge.
if strings.HasSuffix(name, ".acme.invalid") { if wantsTokenCert(hello) {
m.tokensMu.RLock() m.tokensMu.RLock()
defer m.tokensMu.RUnlock() defer m.tokensMu.RUnlock()
// It's ok to use the same token cert key for both tls-sni and tls-alpn
// because there's always at most 1 token cert per on-going domain authorization.
// See m.verify for details.
if cert := m.certTokens[name]; cert != nil { if cert := m.certTokens[name]; cert != nil {
return cert, nil return cert, nil
} }
if cert, err := m.cacheGet(ctx, name); err == nil { if cert, err := m.cacheGet(ctx, certKey{domain: name, isToken: true}); err == nil {
return cert, nil return cert, nil
} }
// TODO: cache error results? // TODO: cache error results?
@ -218,8 +270,11 @@ func (m *Manager) GetCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate,
} }
// regular domain // regular domain
name = strings.TrimSuffix(name, ".") // golang.org/issue/18114 ck := certKey{
cert, err := m.cert(ctx, name) domain: strings.TrimSuffix(name, "."), // golang.org/issue/18114
isRSA: !supportsECDSA(hello),
}
cert, err := m.cert(ctx, ck)
if err == nil { if err == nil {
return cert, nil return cert, nil
} }
@ -231,14 +286,71 @@ func (m *Manager) GetCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate,
if err := m.hostPolicy()(ctx, name); err != nil { if err := m.hostPolicy()(ctx, name); err != nil {
return nil, err return nil, err
} }
cert, err = m.createCert(ctx, name) cert, err = m.createCert(ctx, ck)
if err != nil { if err != nil {
return nil, err return nil, err
} }
m.cachePut(ctx, name, cert) m.cachePut(ctx, ck, cert)
return cert, nil return cert, nil
} }
// wantsTokenCert reports whether a TLS request with SNI is made by a CA server
// for a challenge verification.
func wantsTokenCert(hello *tls.ClientHelloInfo) bool {
// tls-alpn-01
if len(hello.SupportedProtos) == 1 && hello.SupportedProtos[0] == acme.ALPNProto {
return true
}
// tls-sni-xx
return strings.HasSuffix(hello.ServerName, ".acme.invalid")
}
func supportsECDSA(hello *tls.ClientHelloInfo) bool {
// The "signature_algorithms" extension, if present, limits the key exchange
// algorithms allowed by the cipher suites. See RFC 5246, section 7.4.1.4.1.
if hello.SignatureSchemes != nil {
ecdsaOK := false
schemeLoop:
for _, scheme := range hello.SignatureSchemes {
const tlsECDSAWithSHA1 tls.SignatureScheme = 0x0203 // constant added in Go 1.10
switch scheme {
case tlsECDSAWithSHA1, tls.ECDSAWithP256AndSHA256,
tls.ECDSAWithP384AndSHA384, tls.ECDSAWithP521AndSHA512:
ecdsaOK = true
break schemeLoop
}
}
if !ecdsaOK {
return false
}
}
if hello.SupportedCurves != nil {
ecdsaOK := false
for _, curve := range hello.SupportedCurves {
if curve == tls.CurveP256 {
ecdsaOK = true
break
}
}
if !ecdsaOK {
return false
}
}
for _, suite := range hello.CipherSuites {
switch suite {
case tls.TLS_ECDHE_ECDSA_WITH_RC4_128_SHA,
tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA,
tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256,
tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305:
return true
}
}
return false
}
// HTTPHandler configures the Manager to provision ACME "http-01" challenge responses. // HTTPHandler configures the Manager to provision ACME "http-01" challenge responses.
// It returns an http.Handler that responds to the challenges and must be // It returns an http.Handler that responds to the challenges and must be
// running on port 80. If it receives a request that is not an ACME challenge, // running on port 80. If it receives a request that is not an ACME challenge,
@ -252,8 +364,8 @@ func (m *Manager) GetCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate,
// Because the fallback handler is run with unencrypted port 80 requests, // Because the fallback handler is run with unencrypted port 80 requests,
// the fallback should not serve TLS-only requests. // the fallback should not serve TLS-only requests.
// //
// If HTTPHandler is never called, the Manager will only use TLS SNI // If HTTPHandler is never called, the Manager will only use the "tls-alpn-01"
// challenges for domain verification. // challenge for domain verification.
func (m *Manager) HTTPHandler(fallback http.Handler) http.Handler { func (m *Manager) HTTPHandler(fallback http.Handler) http.Handler {
m.tokensMu.Lock() m.tokensMu.Lock()
defer m.tokensMu.Unlock() defer m.tokensMu.Unlock()
@ -304,16 +416,16 @@ func stripPort(hostport string) string {
// cert returns an existing certificate either from m.state or cache. // cert returns an existing certificate either from m.state or cache.
// If a certificate is found in cache but not in m.state, the latter will be filled // If a certificate is found in cache but not in m.state, the latter will be filled
// with the cached value. // with the cached value.
func (m *Manager) cert(ctx context.Context, name string) (*tls.Certificate, error) { func (m *Manager) cert(ctx context.Context, ck certKey) (*tls.Certificate, error) {
m.stateMu.Lock() m.stateMu.Lock()
if s, ok := m.state[name]; ok { if s, ok := m.state[ck]; ok {
m.stateMu.Unlock() m.stateMu.Unlock()
s.RLock() s.RLock()
defer s.RUnlock() defer s.RUnlock()
return s.tlscert() return s.tlscert()
} }
defer m.stateMu.Unlock() defer m.stateMu.Unlock()
cert, err := m.cacheGet(ctx, name) cert, err := m.cacheGet(ctx, ck)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -322,25 +434,25 @@ func (m *Manager) cert(ctx context.Context, name string) (*tls.Certificate, erro
return nil, errors.New("acme/autocert: private key cannot sign") return nil, errors.New("acme/autocert: private key cannot sign")
} }
if m.state == nil { if m.state == nil {
m.state = make(map[string]*certState) m.state = make(map[certKey]*certState)
} }
s := &certState{ s := &certState{
key: signer, key: signer,
cert: cert.Certificate, cert: cert.Certificate,
leaf: cert.Leaf, leaf: cert.Leaf,
} }
m.state[name] = s m.state[ck] = s
go m.renew(name, s.key, s.leaf.NotAfter) go m.renew(ck, s.key, s.leaf.NotAfter)
return cert, nil return cert, nil
} }
// cacheGet always returns a valid certificate, or an error otherwise. // cacheGet always returns a valid certificate, or an error otherwise.
// If a cached certficate exists but is not valid, ErrCacheMiss is returned. // If a cached certificate exists but is not valid, ErrCacheMiss is returned.
func (m *Manager) cacheGet(ctx context.Context, domain string) (*tls.Certificate, error) { func (m *Manager) cacheGet(ctx context.Context, ck certKey) (*tls.Certificate, error) {
if m.Cache == nil { if m.Cache == nil {
return nil, ErrCacheMiss return nil, ErrCacheMiss
} }
data, err := m.Cache.Get(ctx, domain) data, err := m.Cache.Get(ctx, ck.String())
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -371,7 +483,7 @@ func (m *Manager) cacheGet(ctx context.Context, domain string) (*tls.Certificate
} }
// verify and create TLS cert // verify and create TLS cert
leaf, err := validCert(domain, pubDER, privKey) leaf, err := validCert(ck, pubDER, privKey, m.now())
if err != nil { if err != nil {
return nil, ErrCacheMiss return nil, ErrCacheMiss
} }
@ -383,7 +495,7 @@ func (m *Manager) cacheGet(ctx context.Context, domain string) (*tls.Certificate
return tlscert, nil return tlscert, nil
} }
func (m *Manager) cachePut(ctx context.Context, domain string, tlscert *tls.Certificate) error { func (m *Manager) cachePut(ctx context.Context, ck certKey, tlscert *tls.Certificate) error {
if m.Cache == nil { if m.Cache == nil {
return nil return nil
} }
@ -415,7 +527,7 @@ func (m *Manager) cachePut(ctx context.Context, domain string, tlscert *tls.Cert
} }
} }
return m.Cache.Put(ctx, domain, buf.Bytes()) return m.Cache.Put(ctx, ck.String(), buf.Bytes())
} }
func encodeECDSAKey(w io.Writer, key *ecdsa.PrivateKey) error { func encodeECDSAKey(w io.Writer, key *ecdsa.PrivateKey) error {
@ -432,9 +544,9 @@ func encodeECDSAKey(w io.Writer, key *ecdsa.PrivateKey) error {
// //
// If the domain is already being verified, it waits for the existing verification to complete. // If the domain is already being verified, it waits for the existing verification to complete.
// Either way, createCert blocks for the duration of the whole process. // Either way, createCert blocks for the duration of the whole process.
func (m *Manager) createCert(ctx context.Context, domain string) (*tls.Certificate, error) { func (m *Manager) createCert(ctx context.Context, ck certKey) (*tls.Certificate, error) {
// TODO: maybe rewrite this whole piece using sync.Once // TODO: maybe rewrite this whole piece using sync.Once
state, err := m.certState(domain) state, err := m.certState(ck)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -452,44 +564,44 @@ func (m *Manager) createCert(ctx context.Context, domain string) (*tls.Certifica
defer state.Unlock() defer state.Unlock()
state.locked = false state.locked = false
der, leaf, err := m.authorizedCert(ctx, state.key, domain) der, leaf, err := m.authorizedCert(ctx, state.key, ck)
if err != nil { if err != nil {
// Remove the failed state after some time, // Remove the failed state after some time,
// making the manager call createCert again on the following TLS hello. // making the manager call createCert again on the following TLS hello.
time.AfterFunc(createCertRetryAfter, func() { time.AfterFunc(createCertRetryAfter, func() {
defer testDidRemoveState(domain) defer testDidRemoveState(ck)
m.stateMu.Lock() m.stateMu.Lock()
defer m.stateMu.Unlock() defer m.stateMu.Unlock()
// Verify the state hasn't changed and it's still invalid // Verify the state hasn't changed and it's still invalid
// before deleting. // before deleting.
s, ok := m.state[domain] s, ok := m.state[ck]
if !ok { if !ok {
return return
} }
if _, err := validCert(domain, s.cert, s.key); err == nil { if _, err := validCert(ck, s.cert, s.key, m.now()); err == nil {
return return
} }
delete(m.state, domain) delete(m.state, ck)
}) })
return nil, err return nil, err
} }
state.cert = der state.cert = der
state.leaf = leaf state.leaf = leaf
go m.renew(domain, state.key, state.leaf.NotAfter) go m.renew(ck, state.key, state.leaf.NotAfter)
return state.tlscert() return state.tlscert()
} }
// certState returns a new or existing certState. // certState returns a new or existing certState.
// If a new certState is returned, state.exist is false and the state is locked. // If a new certState is returned, state.exist is false and the state is locked.
// The returned error is non-nil only in the case where a new state could not be created. // The returned error is non-nil only in the case where a new state could not be created.
func (m *Manager) certState(domain string) (*certState, error) { func (m *Manager) certState(ck certKey) (*certState, error) {
m.stateMu.Lock() m.stateMu.Lock()
defer m.stateMu.Unlock() defer m.stateMu.Unlock()
if m.state == nil { if m.state == nil {
m.state = make(map[string]*certState) m.state = make(map[certKey]*certState)
} }
// existing state // existing state
if state, ok := m.state[domain]; ok { if state, ok := m.state[ck]; ok {
return state, nil return state, nil
} }
@ -498,7 +610,7 @@ func (m *Manager) certState(domain string) (*certState, error) {
err error err error
key crypto.Signer key crypto.Signer
) )
if m.ForceRSA { if ck.isRSA {
key, err = rsa.GenerateKey(rand.Reader, 2048) key, err = rsa.GenerateKey(rand.Reader, 2048)
} else { } else {
key, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader) key, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
@ -512,22 +624,22 @@ func (m *Manager) certState(domain string) (*certState, error) {
locked: true, locked: true,
} }
state.Lock() // will be unlocked by m.certState caller state.Lock() // will be unlocked by m.certState caller
m.state[domain] = state m.state[ck] = state
return state, nil return state, nil
} }
// authorizedCert starts the domain ownership verification process and requests a new cert upon success. // authorizedCert starts the domain ownership verification process and requests a new cert upon success.
// The key argument is the certificate private key. // The key argument is the certificate private key.
func (m *Manager) authorizedCert(ctx context.Context, key crypto.Signer, domain string) (der [][]byte, leaf *x509.Certificate, err error) { func (m *Manager) authorizedCert(ctx context.Context, key crypto.Signer, ck certKey) (der [][]byte, leaf *x509.Certificate, err error) {
client, err := m.acmeClient(ctx) client, err := m.acmeClient(ctx)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
if err := m.verify(ctx, client, domain); err != nil { if err := m.verify(ctx, client, ck.domain); err != nil {
return nil, nil, err return nil, nil, err
} }
csr, err := certRequest(key, domain) csr, err := certRequest(key, ck.domain, m.ExtraExtensions)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@ -535,25 +647,55 @@ func (m *Manager) authorizedCert(ctx context.Context, key crypto.Signer, domain
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
leaf, err = validCert(domain, der, key) leaf, err = validCert(ck, der, key, m.now())
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
return der, leaf, nil return der, leaf, nil
} }
// revokePendingAuthz revokes all authorizations idenfied by the elements of uri slice.
// It ignores revocation errors.
func (m *Manager) revokePendingAuthz(ctx context.Context, uri []string) {
client, err := m.acmeClient(ctx)
if err != nil {
return
}
for _, u := range uri {
client.RevokeAuthorization(ctx, u)
}
}
// verify runs the identifier (domain) authorization flow // verify runs the identifier (domain) authorization flow
// using each applicable ACME challenge type. // using each applicable ACME challenge type.
func (m *Manager) verify(ctx context.Context, client *acme.Client, domain string) error { func (m *Manager) verify(ctx context.Context, client *acme.Client, domain string) error {
// The list of challenge types we'll try to fulfill // The list of challenge types we'll try to fulfill
// in this specific order. // in this specific order.
challengeTypes := []string{"tls-sni-02", "tls-sni-01"} challengeTypes := []string{"tls-alpn-01", "tls-sni-02", "tls-sni-01"}
m.tokensMu.RLock() m.tokensMu.RLock()
if m.tryHTTP01 { if m.tryHTTP01 {
challengeTypes = append(challengeTypes, "http-01") challengeTypes = append(challengeTypes, "http-01")
} }
m.tokensMu.RUnlock() m.tokensMu.RUnlock()
// Keep track of pending authzs and revoke the ones that did not validate.
pendingAuthzs := make(map[string]bool)
defer func() {
var uri []string
for k, pending := range pendingAuthzs {
if pending {
uri = append(uri, k)
}
}
if len(uri) > 0 {
// Use "detached" background context.
// The revocations need not happen in the current verification flow.
go m.revokePendingAuthz(context.Background(), uri)
}
}()
// errs accumulates challenge failure errors, printed if all fail
errs := make(map[*acme.Challenge]error)
var nextTyp int // challengeType index of the next challenge type to try var nextTyp int // challengeType index of the next challenge type to try
for { for {
// Start domain authorization and get the challenge. // Start domain authorization and get the challenge.
@ -570,6 +712,8 @@ func (m *Manager) verify(ctx context.Context, client *acme.Client, domain string
return fmt.Errorf("acme/autocert: invalid authorization %q", authz.URI) return fmt.Errorf("acme/autocert: invalid authorization %q", authz.URI)
} }
pendingAuthzs[authz.URI] = true
// Pick the next preferred challenge. // Pick the next preferred challenge.
var chal *acme.Challenge var chal *acme.Challenge
for chal == nil && nextTyp < len(challengeTypes) { for chal == nil && nextTyp < len(challengeTypes) {
@ -577,28 +721,44 @@ func (m *Manager) verify(ctx context.Context, client *acme.Client, domain string
nextTyp++ nextTyp++
} }
if chal == nil { if chal == nil {
return fmt.Errorf("acme/autocert: unable to authorize %q; tried %q", domain, challengeTypes) errorMsg := fmt.Sprintf("acme/autocert: unable to authorize %q", domain)
for chal, err := range errs {
errorMsg += fmt.Sprintf("; challenge %q failed with error: %v", chal.Type, err)
} }
cleanup, err := m.fulfill(ctx, client, chal) return errors.New(errorMsg)
}
cleanup, err := m.fulfill(ctx, client, chal, domain)
if err != nil { if err != nil {
errs[chal] = err
continue continue
} }
defer cleanup() defer cleanup()
if _, err := client.Accept(ctx, chal); err != nil { if _, err := client.Accept(ctx, chal); err != nil {
errs[chal] = err
continue continue
} }
// A challenge is fulfilled and accepted: wait for the CA to validate. // A challenge is fulfilled and accepted: wait for the CA to validate.
if _, err := client.WaitAuthorization(ctx, authz.URI); err == nil { if _, err := client.WaitAuthorization(ctx, authz.URI); err != nil {
return nil errs[chal] = err
continue
} }
delete(pendingAuthzs, authz.URI)
return nil
} }
} }
// fulfill provisions a response to the challenge chal. // fulfill provisions a response to the challenge chal.
// The cleanup is non-nil only if provisioning succeeded. // The cleanup is non-nil only if provisioning succeeded.
func (m *Manager) fulfill(ctx context.Context, client *acme.Client, chal *acme.Challenge) (cleanup func(), err error) { func (m *Manager) fulfill(ctx context.Context, client *acme.Client, chal *acme.Challenge, domain string) (cleanup func(), err error) {
switch chal.Type { switch chal.Type {
case "tls-alpn-01":
cert, err := client.TLSALPN01ChallengeCert(chal.Token, domain)
if err != nil {
return nil, err
}
m.putCertToken(ctx, domain, &cert)
return func() { go m.deleteCertToken(domain) }, nil
case "tls-sni-01": case "tls-sni-01":
cert, name, err := client.TLSSNI01ChallengeCert(chal.Token) cert, name, err := client.TLSSNI01ChallengeCert(chal.Token)
if err != nil { if err != nil {
@ -634,8 +794,8 @@ func pickChallenge(typ string, chal []*acme.Challenge) *acme.Challenge {
return nil return nil
} }
// putCertToken stores the cert under the named key in both m.certTokens map // putCertToken stores the token certificate with the specified name
// and m.Cache. // in both m.certTokens map and m.Cache.
func (m *Manager) putCertToken(ctx context.Context, name string, cert *tls.Certificate) { func (m *Manager) putCertToken(ctx context.Context, name string, cert *tls.Certificate) {
m.tokensMu.Lock() m.tokensMu.Lock()
defer m.tokensMu.Unlock() defer m.tokensMu.Unlock()
@ -643,17 +803,18 @@ func (m *Manager) putCertToken(ctx context.Context, name string, cert *tls.Certi
m.certTokens = make(map[string]*tls.Certificate) m.certTokens = make(map[string]*tls.Certificate)
} }
m.certTokens[name] = cert m.certTokens[name] = cert
m.cachePut(ctx, name, cert) m.cachePut(ctx, certKey{domain: name, isToken: true}, cert)
} }
// deleteCertToken removes the token certificate for the specified domain name // deleteCertToken removes the token certificate with the specified name
// from both m.certTokens map and m.Cache. // from both m.certTokens map and m.Cache.
func (m *Manager) deleteCertToken(name string) { func (m *Manager) deleteCertToken(name string) {
m.tokensMu.Lock() m.tokensMu.Lock()
defer m.tokensMu.Unlock() defer m.tokensMu.Unlock()
delete(m.certTokens, name) delete(m.certTokens, name)
if m.Cache != nil { if m.Cache != nil {
m.Cache.Delete(context.Background(), name) ck := certKey{domain: name, isToken: true}
m.Cache.Delete(context.Background(), ck.String())
} }
} }
@ -704,7 +865,7 @@ func (m *Manager) deleteHTTPToken(tokenPath string) {
// httpTokenCacheKey returns a key at which an http-01 token value may be stored // httpTokenCacheKey returns a key at which an http-01 token value may be stored
// in the Manager's optional Cache. // in the Manager's optional Cache.
func httpTokenCacheKey(tokenPath string) string { func httpTokenCacheKey(tokenPath string) string {
return "http-01-" + path.Base(tokenPath) return path.Base(tokenPath) + "+http-01"
} }
// renew starts a cert renewal timer loop, one per domain. // renew starts a cert renewal timer loop, one per domain.
@ -715,18 +876,18 @@ func httpTokenCacheKey(tokenPath string) string {
// //
// The key argument is a certificate private key. // The key argument is a certificate private key.
// The exp argument is the cert expiration time (NotAfter). // The exp argument is the cert expiration time (NotAfter).
func (m *Manager) renew(domain string, key crypto.Signer, exp time.Time) { func (m *Manager) renew(ck certKey, key crypto.Signer, exp time.Time) {
m.renewalMu.Lock() m.renewalMu.Lock()
defer m.renewalMu.Unlock() defer m.renewalMu.Unlock()
if m.renewal[domain] != nil { if m.renewal[ck] != nil {
// another goroutine is already on it // another goroutine is already on it
return return
} }
if m.renewal == nil { if m.renewal == nil {
m.renewal = make(map[string]*domainRenewal) m.renewal = make(map[certKey]*domainRenewal)
} }
dr := &domainRenewal{m: m, domain: domain, key: key} dr := &domainRenewal{m: m, ck: ck, key: key}
m.renewal[domain] = dr m.renewal[ck] = dr
dr.start(exp) dr.start(exp)
} }
@ -742,7 +903,10 @@ func (m *Manager) stopRenew() {
} }
func (m *Manager) accountKey(ctx context.Context) (crypto.Signer, error) { func (m *Manager) accountKey(ctx context.Context) (crypto.Signer, error) {
const keyName = "acme_account.key" const keyName = "acme_account+key"
// Previous versions of autocert stored the value under a different key.
const legacyKeyName = "acme_account.key"
genKey := func() (*ecdsa.PrivateKey, error) { genKey := func() (*ecdsa.PrivateKey, error) {
return ecdsa.GenerateKey(elliptic.P256(), rand.Reader) return ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
@ -753,6 +917,9 @@ func (m *Manager) accountKey(ctx context.Context) (crypto.Signer, error) {
} }
data, err := m.Cache.Get(ctx, keyName) data, err := m.Cache.Get(ctx, keyName)
if err == ErrCacheMiss {
data, err = m.Cache.Get(ctx, legacyKeyName)
}
if err == ErrCacheMiss { if err == ErrCacheMiss {
key, err := genKey() key, err := genKey()
if err != nil { if err != nil {
@ -824,6 +991,13 @@ func (m *Manager) renewBefore() time.Duration {
return 720 * time.Hour // 30 days return 720 * time.Hour // 30 days
} }
func (m *Manager) now() time.Time {
if m.nowFunc != nil {
return m.nowFunc()
}
return time.Now()
}
// certState is ready when its mutex is unlocked for reading. // certState is ready when its mutex is unlocked for reading.
type certState struct { type certState struct {
sync.RWMutex sync.RWMutex
@ -849,12 +1023,12 @@ func (s *certState) tlscert() (*tls.Certificate, error) {
}, nil }, nil
} }
// certRequest creates a certificate request for the given common name cn // certRequest generates a CSR for the given common name cn and optional SANs.
// and optional SANs. func certRequest(key crypto.Signer, cn string, ext []pkix.Extension, san ...string) ([]byte, error) {
func certRequest(key crypto.Signer, cn string, san ...string) ([]byte, error) {
req := &x509.CertificateRequest{ req := &x509.CertificateRequest{
Subject: pkix.Name{CommonName: cn}, Subject: pkix.Name{CommonName: cn},
DNSNames: san, DNSNames: san,
ExtraExtensions: ext,
} }
return x509.CreateCertificateRequest(rand.Reader, req, key) return x509.CreateCertificateRequest(rand.Reader, req, key)
} }
@ -885,12 +1059,12 @@ func parsePrivateKey(der []byte) (crypto.Signer, error) {
return nil, errors.New("acme/autocert: failed to parse private key") return nil, errors.New("acme/autocert: failed to parse private key")
} }
// validCert parses a cert chain provided as der argument and verifies the leaf, der[0], // validCert parses a cert chain provided as der argument and verifies the leaf and der[0]
// corresponds to the private key, as well as the domain match and expiration dates. // correspond to the private key, the domain and key type match, and expiration dates
// It doesn't do any revocation checking. // are valid. It doesn't do any revocation checking.
// //
// The returned value is the verified leaf cert. // The returned value is the verified leaf cert.
func validCert(domain string, der [][]byte, key crypto.Signer) (leaf *x509.Certificate, err error) { func validCert(ck certKey, der [][]byte, key crypto.Signer, now time.Time) (leaf *x509.Certificate, err error) {
// parse public part(s) // parse public part(s)
var n int var n int
for _, b := range der { for _, b := range der {
@ -902,22 +1076,21 @@ func validCert(domain string, der [][]byte, key crypto.Signer) (leaf *x509.Certi
n += copy(pub[n:], b) n += copy(pub[n:], b)
} }
x509Cert, err := x509.ParseCertificates(pub) x509Cert, err := x509.ParseCertificates(pub)
if len(x509Cert) == 0 { if err != nil || len(x509Cert) == 0 {
return nil, errors.New("acme/autocert: no public key found") return nil, errors.New("acme/autocert: no public key found")
} }
// verify the leaf is not expired and matches the domain name // verify the leaf is not expired and matches the domain name
leaf = x509Cert[0] leaf = x509Cert[0]
now := timeNow()
if now.Before(leaf.NotBefore) { if now.Before(leaf.NotBefore) {
return nil, errors.New("acme/autocert: certificate is not valid yet") return nil, errors.New("acme/autocert: certificate is not valid yet")
} }
if now.After(leaf.NotAfter) { if now.After(leaf.NotAfter) {
return nil, errors.New("acme/autocert: expired certificate") return nil, errors.New("acme/autocert: expired certificate")
} }
if err := leaf.VerifyHostname(domain); err != nil { if err := leaf.VerifyHostname(ck.domain); err != nil {
return nil, err return nil, err
} }
// ensure the leaf corresponds to the private key // ensure the leaf corresponds to the private key and matches the certKey type
switch pub := leaf.PublicKey.(type) { switch pub := leaf.PublicKey.(type) {
case *rsa.PublicKey: case *rsa.PublicKey:
prv, ok := key.(*rsa.PrivateKey) prv, ok := key.(*rsa.PrivateKey)
@ -927,6 +1100,9 @@ func validCert(domain string, der [][]byte, key crypto.Signer) (leaf *x509.Certi
if pub.N.Cmp(prv.N) != 0 { if pub.N.Cmp(prv.N) != 0 {
return nil, errors.New("acme/autocert: private key does not match public key") return nil, errors.New("acme/autocert: private key does not match public key")
} }
if !ck.isRSA && !ck.isToken {
return nil, errors.New("acme/autocert: key type does not match expected value")
}
case *ecdsa.PublicKey: case *ecdsa.PublicKey:
prv, ok := key.(*ecdsa.PrivateKey) prv, ok := key.(*ecdsa.PrivateKey)
if !ok { if !ok {
@ -935,6 +1111,9 @@ func validCert(domain string, der [][]byte, key crypto.Signer) (leaf *x509.Certi
if pub.X.Cmp(prv.X) != 0 || pub.Y.Cmp(prv.Y) != 0 { if pub.X.Cmp(prv.X) != 0 || pub.Y.Cmp(prv.Y) != 0 {
return nil, errors.New("acme/autocert: private key does not match public key") return nil, errors.New("acme/autocert: private key does not match public key")
} }
if ck.isRSA && !ck.isToken {
return nil, errors.New("acme/autocert: key type does not match expected value")
}
default: default:
return nil, errors.New("acme/autocert: unknown public key algorithm") return nil, errors.New("acme/autocert: unknown public key algorithm")
} }
@ -955,8 +1134,6 @@ func (r *lockedMathRand) int63n(max int64) int64 {
// For easier testing. // For easier testing.
var ( var (
timeNow = time.Now
// Called when a state is removed. // Called when a state is removed.
testDidRemoveState = func(domain string) {} testDidRemoveState = func(certKey) {}
) )

View File

@ -16,10 +16,10 @@ import (
var ErrCacheMiss = errors.New("acme/autocert: certificate cache miss") var ErrCacheMiss = errors.New("acme/autocert: certificate cache miss")
// Cache is used by Manager to store and retrieve previously obtained certificates // Cache is used by Manager to store and retrieve previously obtained certificates
// as opaque data. // and other account data as opaque blobs.
// //
// The key argument of the methods refers to a domain name but need not be an FQDN. // Cache implementations should not rely on the key naming pattern. Keys can
// Cache implementations should not rely on the key naming pattern. // include any printable ASCII characters, except the following: \/:*?"<>|
type Cache interface { type Cache interface {
// Get returns a certificate data for the specified key. // Get returns a certificate data for the specified key.
// If there's no such key, Get returns ErrCacheMiss. // If there's no such key, Get returns ErrCacheMiss.

View File

@ -73,10 +73,7 @@ func NewListener(domains ...string) net.Listener {
func (m *Manager) Listener() net.Listener { func (m *Manager) Listener() net.Listener {
ln := &listener{ ln := &listener{
m: m, m: m,
conf: &tls.Config{ conf: m.TLSConfig(),
GetCertificate: m.GetCertificate, // bonus: panic on nil m
NextProtos: []string{"h2", "http/1.1"}, // Enable HTTP/2
},
} }
ln.tcpListener, ln.tcpListenErr = net.Listen("tcp", ":443") ln.tcpListener, ln.tcpListenErr = net.Listen("tcp", ":443")
return ln return ln

View File

@ -18,7 +18,7 @@ const renewJitter = time.Hour
// renewing a single domain's cert. // renewing a single domain's cert.
type domainRenewal struct { type domainRenewal struct {
m *Manager m *Manager
domain string ck certKey
key crypto.Signer key crypto.Signer
timerMu sync.Mutex timerMu sync.Mutex
@ -71,25 +71,43 @@ func (dr *domainRenewal) renew() {
testDidRenewLoop(next, err) testDidRenewLoop(next, err)
} }
// updateState locks and replaces the relevant Manager.state item with the given
// state. It additionally updates dr.key with the given state's key.
func (dr *domainRenewal) updateState(state *certState) {
dr.m.stateMu.Lock()
defer dr.m.stateMu.Unlock()
dr.key = state.key
dr.m.state[dr.ck] = state
}
// do is similar to Manager.createCert but it doesn't lock a Manager.state item. // do is similar to Manager.createCert but it doesn't lock a Manager.state item.
// Instead, it requests a new certificate independently and, upon success, // Instead, it requests a new certificate independently and, upon success,
// replaces dr.m.state item with a new one and updates cache for the given domain. // replaces dr.m.state item with a new one and updates cache for the given domain.
// //
// It may return immediately if the expiration date of the currently cached cert // It may lock and update the Manager.state if the expiration date of the currently
// is far enough in the future. // cached cert is far enough in the future.
// //
// The returned value is a time interval after which the renewal should occur again. // The returned value is a time interval after which the renewal should occur again.
func (dr *domainRenewal) do(ctx context.Context) (time.Duration, error) { func (dr *domainRenewal) do(ctx context.Context) (time.Duration, error) {
// a race is likely unavoidable in a distributed environment // a race is likely unavoidable in a distributed environment
// but we try nonetheless // but we try nonetheless
if tlscert, err := dr.m.cacheGet(ctx, dr.domain); err == nil { if tlscert, err := dr.m.cacheGet(ctx, dr.ck); err == nil {
next := dr.next(tlscert.Leaf.NotAfter) next := dr.next(tlscert.Leaf.NotAfter)
if next > dr.m.renewBefore()+renewJitter { if next > dr.m.renewBefore()+renewJitter {
signer, ok := tlscert.PrivateKey.(crypto.Signer)
if ok {
state := &certState{
key: signer,
cert: tlscert.Certificate,
leaf: tlscert.Leaf,
}
dr.updateState(state)
return next, nil return next, nil
} }
} }
}
der, leaf, err := dr.m.authorizedCert(ctx, dr.key, dr.domain) der, leaf, err := dr.m.authorizedCert(ctx, dr.key, dr.ck)
if err != nil { if err != nil {
return 0, err return 0, err
} }
@ -102,16 +120,15 @@ func (dr *domainRenewal) do(ctx context.Context) (time.Duration, error) {
if err != nil { if err != nil {
return 0, err return 0, err
} }
dr.m.cachePut(ctx, dr.domain, tlscert) if err := dr.m.cachePut(ctx, dr.ck, tlscert); err != nil {
dr.m.stateMu.Lock() return 0, err
defer dr.m.stateMu.Unlock() }
// m.state is guaranteed to be non-nil at this point dr.updateState(state)
dr.m.state[dr.domain] = state
return dr.next(leaf.NotAfter), nil return dr.next(leaf.NotAfter), nil
} }
func (dr *domainRenewal) next(expiry time.Time) time.Duration { func (dr *domainRenewal) next(expiry time.Time) time.Duration {
d := expiry.Sub(timeNow()) - dr.m.renewBefore() d := expiry.Sub(dr.m.now()) - dr.m.renewBefore()
// add a bit of randomness to renew deadline // add a bit of randomness to renew deadline
n := pseudoRand.int63n(int64(renewJitter)) n := pseudoRand.int63n(int64(renewJitter))
d -= time.Duration(n) d -= time.Duration(n)

281
vendor/golang.org/x/crypto/acme/http.go generated vendored Normal file
View File

@ -0,0 +1,281 @@
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package acme
import (
"bytes"
"context"
"crypto"
"crypto/rand"
"encoding/json"
"fmt"
"io/ioutil"
"math/big"
"net/http"
"strconv"
"strings"
"time"
)
// retryTimer encapsulates common logic for retrying unsuccessful requests.
// It is not safe for concurrent use.
type retryTimer struct {
// backoffFn provides backoff delay sequence for retries.
// See Client.RetryBackoff doc comment.
backoffFn func(n int, r *http.Request, res *http.Response) time.Duration
// n is the current retry attempt.
n int
}
func (t *retryTimer) inc() {
t.n++
}
// backoff pauses the current goroutine as described in Client.RetryBackoff.
func (t *retryTimer) backoff(ctx context.Context, r *http.Request, res *http.Response) error {
d := t.backoffFn(t.n, r, res)
if d <= 0 {
return fmt.Errorf("acme: no more retries for %s; tried %d time(s)", r.URL, t.n)
}
wakeup := time.NewTimer(d)
defer wakeup.Stop()
select {
case <-ctx.Done():
return ctx.Err()
case <-wakeup.C:
return nil
}
}
func (c *Client) retryTimer() *retryTimer {
f := c.RetryBackoff
if f == nil {
f = defaultBackoff
}
return &retryTimer{backoffFn: f}
}
// defaultBackoff provides default Client.RetryBackoff implementation
// using a truncated exponential backoff algorithm,
// as described in Client.RetryBackoff.
//
// The n argument is always bounded between 1 and 30.
// The returned value is always greater than 0.
func defaultBackoff(n int, r *http.Request, res *http.Response) time.Duration {
const max = 10 * time.Second
var jitter time.Duration
if x, err := rand.Int(rand.Reader, big.NewInt(1000)); err == nil {
// Set the minimum to 1ms to avoid a case where
// an invalid Retry-After value is parsed into 0 below,
// resulting in the 0 returned value which would unintentionally
// stop the retries.
jitter = (1 + time.Duration(x.Int64())) * time.Millisecond
}
if v, ok := res.Header["Retry-After"]; ok {
return retryAfter(v[0]) + jitter
}
if n < 1 {
n = 1
}
if n > 30 {
n = 30
}
d := time.Duration(1<<uint(n-1))*time.Second + jitter
if d > max {
return max
}
return d
}
// retryAfter parses a Retry-After HTTP header value,
// trying to convert v into an int (seconds) or use http.ParseTime otherwise.
// It returns zero value if v cannot be parsed.
func retryAfter(v string) time.Duration {
if i, err := strconv.Atoi(v); err == nil {
return time.Duration(i) * time.Second
}
t, err := http.ParseTime(v)
if err != nil {
return 0
}
return t.Sub(timeNow())
}
// resOkay is a function that reports whether the provided response is okay.
// It is expected to keep the response body unread.
type resOkay func(*http.Response) bool
// wantStatus returns a function which reports whether the code
// matches the status code of a response.
func wantStatus(codes ...int) resOkay {
return func(res *http.Response) bool {
for _, code := range codes {
if code == res.StatusCode {
return true
}
}
return false
}
}
// get issues an unsigned GET request to the specified URL.
// It returns a non-error value only when ok reports true.
//
// get retries unsuccessful attempts according to c.RetryBackoff
// until the context is done or a non-retriable error is received.
func (c *Client) get(ctx context.Context, url string, ok resOkay) (*http.Response, error) {
retry := c.retryTimer()
for {
req, err := http.NewRequest("GET", url, nil)
if err != nil {
return nil, err
}
res, err := c.doNoRetry(ctx, req)
switch {
case err != nil:
return nil, err
case ok(res):
return res, nil
case isRetriable(res.StatusCode):
retry.inc()
resErr := responseError(res)
res.Body.Close()
// Ignore the error value from retry.backoff
// and return the one from last retry, as received from the CA.
if retry.backoff(ctx, req, res) != nil {
return nil, resErr
}
default:
defer res.Body.Close()
return nil, responseError(res)
}
}
}
// post issues a signed POST request in JWS format using the provided key
// to the specified URL.
// It returns a non-error value only when ok reports true.
//
// post retries unsuccessful attempts according to c.RetryBackoff
// until the context is done or a non-retriable error is received.
// It uses postNoRetry to make individual requests.
func (c *Client) post(ctx context.Context, key crypto.Signer, url string, body interface{}, ok resOkay) (*http.Response, error) {
retry := c.retryTimer()
for {
res, req, err := c.postNoRetry(ctx, key, url, body)
if err != nil {
return nil, err
}
if ok(res) {
return res, nil
}
resErr := responseError(res)
res.Body.Close()
switch {
// Check for bad nonce before isRetriable because it may have been returned
// with an unretriable response code such as 400 Bad Request.
case isBadNonce(resErr):
// Consider any previously stored nonce values to be invalid.
c.clearNonces()
case !isRetriable(res.StatusCode):
return nil, resErr
}
retry.inc()
// Ignore the error value from retry.backoff
// and return the one from last retry, as received from the CA.
if err := retry.backoff(ctx, req, res); err != nil {
return nil, resErr
}
}
}
// postNoRetry signs the body with the given key and POSTs it to the provided url.
// The body argument must be JSON-serializable.
// It is used by c.post to retry unsuccessful attempts.
func (c *Client) postNoRetry(ctx context.Context, key crypto.Signer, url string, body interface{}) (*http.Response, *http.Request, error) {
nonce, err := c.popNonce(ctx, url)
if err != nil {
return nil, nil, err
}
b, err := jwsEncodeJSON(body, key, nonce)
if err != nil {
return nil, nil, err
}
req, err := http.NewRequest("POST", url, bytes.NewReader(b))
if err != nil {
return nil, nil, err
}
req.Header.Set("Content-Type", "application/jose+json")
res, err := c.doNoRetry(ctx, req)
if err != nil {
return nil, nil, err
}
c.addNonce(res.Header)
return res, req, nil
}
// doNoRetry issues a request req, replacing its context (if any) with ctx.
func (c *Client) doNoRetry(ctx context.Context, req *http.Request) (*http.Response, error) {
res, err := c.httpClient().Do(req.WithContext(ctx))
if err != nil {
select {
case <-ctx.Done():
// Prefer the unadorned context error.
// (The acme package had tests assuming this, previously from ctxhttp's
// behavior, predating net/http supporting contexts natively)
// TODO(bradfitz): reconsider this in the future. But for now this
// requires no test updates.
return nil, ctx.Err()
default:
return nil, err
}
}
return res, nil
}
func (c *Client) httpClient() *http.Client {
if c.HTTPClient != nil {
return c.HTTPClient
}
return http.DefaultClient
}
// isBadNonce reports whether err is an ACME "badnonce" error.
func isBadNonce(err error) bool {
// According to the spec badNonce is urn:ietf:params:acme:error:badNonce.
// However, ACME servers in the wild return their versions of the error.
// See https://tools.ietf.org/html/draft-ietf-acme-acme-02#section-5.4
// and https://github.com/letsencrypt/boulder/blob/0e07eacb/docs/acme-divergences.md#section-66.
ae, ok := err.(*Error)
return ok && strings.HasSuffix(strings.ToLower(ae.ProblemType), ":badnonce")
}
// isRetriable reports whether a request can be retried
// based on the response status code.
//
// Note that a "bad nonce" error is returned with a non-retriable 400 Bad Request code.
// Callers should parse the response and check with isBadNonce.
func isRetriable(code int) bool {
return code <= 399 || code >= 500 || code == http.StatusTooManyRequests
}
// responseError creates an error of Error type from resp.
func responseError(resp *http.Response) error {
// don't care if ReadAll returns an error:
// json.Unmarshal will fail in that case anyway
b, _ := ioutil.ReadAll(resp.Body)
e := &wireError{Status: resp.StatusCode}
if err := json.Unmarshal(b, e); err != nil {
// this is not a regular error response:
// populate detail with anything we received,
// e.Status will already contain HTTP response code value
e.Detail = string(b)
if e.Detail == "" {
e.Detail = resp.Status
}
}
return e.error(resp.Header)
}

View File

@ -104,7 +104,7 @@ func RateLimit(err error) (time.Duration, bool) {
if e.Header == nil { if e.Header == nil {
return 0, true return 0, true
} }
return retryAfter(e.Header.Get("Retry-After"), 0), true return retryAfter(e.Header.Get("Retry-After")), true
} }
// Account is a user account. It is associated with a private key. // Account is a user account. It is associated with a private key.
@ -296,8 +296,8 @@ func (e *wireError) error(h http.Header) *Error {
} }
} }
// CertOption is an optional argument type for the TLSSNIxChallengeCert methods for // CertOption is an optional argument type for the TLS ChallengeCert methods for
// customizing a temporary certificate for TLS-SNI challenges. // customizing a temporary certificate for TLS-based challenges.
type CertOption interface { type CertOption interface {
privateCertOpt() privateCertOpt()
} }
@ -317,7 +317,7 @@ func (*certOptKey) privateCertOpt() {}
// WithTemplate creates an option for specifying a certificate template. // WithTemplate creates an option for specifying a certificate template.
// See x509.CreateCertificate for template usage details. // See x509.CreateCertificate for template usage details.
// //
// In TLSSNIxChallengeCert methods, the template is also used as parent, // In TLS ChallengeCert methods, the template is also used as parent,
// resulting in a self-signed certificate. // resulting in a self-signed certificate.
// The DNSNames field of t is always overwritten for tls-sni challenge certs. // The DNSNames field of t is always overwritten for tls-sni challenge certs.
func WithTemplate(t *x509.Certificate) CertOption { func WithTemplate(t *x509.Certificate) CertOption {

View File

@ -6,7 +6,10 @@
// https://ed25519.cr.yp.to/. // https://ed25519.cr.yp.to/.
// //
// These functions are also compatible with the “Ed25519” function defined in // These functions are also compatible with the “Ed25519” function defined in
// RFC 8032. // RFC 8032. However, unlike RFC 8032's formulation, this package's private key
// representation includes a public key suffix to make multiple signing
// operations with the same key more efficient. This package refers to the RFC
// 8032 private key as the “seed”.
package ed25519 package ed25519
// This code is a port of the public domain, “ref10” implementation of ed25519 // This code is a port of the public domain, “ref10” implementation of ed25519
@ -31,6 +34,8 @@ const (
PrivateKeySize = 64 PrivateKeySize = 64
// SignatureSize is the size, in bytes, of signatures generated and verified by this package. // SignatureSize is the size, in bytes, of signatures generated and verified by this package.
SignatureSize = 64 SignatureSize = 64
// SeedSize is the size, in bytes, of private key seeds. These are the private key representations used by RFC 8032.
SeedSize = 32
) )
// PublicKey is the type of Ed25519 public keys. // PublicKey is the type of Ed25519 public keys.
@ -46,6 +51,15 @@ func (priv PrivateKey) Public() crypto.PublicKey {
return PublicKey(publicKey) return PublicKey(publicKey)
} }
// Seed returns the private key seed corresponding to priv. It is provided for
// interoperability with RFC 8032. RFC 8032's private keys correspond to seeds
// in this package.
func (priv PrivateKey) Seed() []byte {
seed := make([]byte, SeedSize)
copy(seed, priv[:32])
return seed
}
// Sign signs the given message with priv. // Sign signs the given message with priv.
// Ed25519 performs two passes over messages to be signed and therefore cannot // Ed25519 performs two passes over messages to be signed and therefore cannot
// handle pre-hashed messages. Thus opts.HashFunc() must return zero to // handle pre-hashed messages. Thus opts.HashFunc() must return zero to
@ -61,19 +75,33 @@ func (priv PrivateKey) Sign(rand io.Reader, message []byte, opts crypto.SignerOp
// GenerateKey generates a public/private key pair using entropy from rand. // GenerateKey generates a public/private key pair using entropy from rand.
// If rand is nil, crypto/rand.Reader will be used. // If rand is nil, crypto/rand.Reader will be used.
func GenerateKey(rand io.Reader) (publicKey PublicKey, privateKey PrivateKey, err error) { func GenerateKey(rand io.Reader) (PublicKey, PrivateKey, error) {
if rand == nil { if rand == nil {
rand = cryptorand.Reader rand = cryptorand.Reader
} }
privateKey = make([]byte, PrivateKeySize) seed := make([]byte, SeedSize)
publicKey = make([]byte, PublicKeySize) if _, err := io.ReadFull(rand, seed); err != nil {
_, err = io.ReadFull(rand, privateKey[:32])
if err != nil {
return nil, nil, err return nil, nil, err
} }
digest := sha512.Sum512(privateKey[:32]) privateKey := NewKeyFromSeed(seed)
publicKey := make([]byte, PublicKeySize)
copy(publicKey, privateKey[32:])
return publicKey, privateKey, nil
}
// NewKeyFromSeed calculates a private key from a seed. It will panic if
// len(seed) is not SeedSize. This function is provided for interoperability
// with RFC 8032. RFC 8032's private keys correspond to seeds in this
// package.
func NewKeyFromSeed(seed []byte) PrivateKey {
if l := len(seed); l != SeedSize {
panic("ed25519: bad seed length: " + strconv.Itoa(l))
}
digest := sha512.Sum512(seed)
digest[0] &= 248 digest[0] &= 248
digest[31] &= 127 digest[31] &= 127
digest[31] |= 64 digest[31] |= 64
@ -85,10 +113,11 @@ func GenerateKey(rand io.Reader) (publicKey PublicKey, privateKey PrivateKey, er
var publicKeyBytes [32]byte var publicKeyBytes [32]byte
A.ToBytes(&publicKeyBytes) A.ToBytes(&publicKeyBytes)
privateKey := make([]byte, PrivateKeySize)
copy(privateKey, seed)
copy(privateKey[32:], publicKeyBytes[:]) copy(privateKey[32:], publicKeyBytes[:])
copy(publicKey, publicKeyBytes[:])
return publicKey, privateKey, nil return privateKey
} }
// Sign signs the message with privateKey and returns a signature. It will // Sign signs the message with privateKey and returns a signature. It will
@ -171,9 +200,16 @@ func Verify(publicKey PublicKey, message, sig []byte) bool {
edwards25519.ScReduce(&hReduced, &digest) edwards25519.ScReduce(&hReduced, &digest)
var R edwards25519.ProjectiveGroupElement var R edwards25519.ProjectiveGroupElement
var b [32]byte var s [32]byte
copy(b[:], sig[32:]) copy(s[:], sig[32:])
edwards25519.GeDoubleScalarMultVartime(&R, &hReduced, &A, &b)
// https://tools.ietf.org/html/rfc8032#section-5.1.7 requires that s be in
// the range [0, order) in order to prevent signature malleability.
if !edwards25519.ScMinimal(&s) {
return false
}
edwards25519.GeDoubleScalarMultVartime(&R, &hReduced, &A, &s)
var checkR [32]byte var checkR [32]byte
R.ToBytes(&checkR) R.ToBytes(&checkR)

View File

@ -4,6 +4,8 @@
package edwards25519 package edwards25519
import "encoding/binary"
// This code is a port of the public domain, “ref10” implementation of ed25519 // This code is a port of the public domain, “ref10” implementation of ed25519
// from SUPERCOP. // from SUPERCOP.
@ -1769,3 +1771,23 @@ func ScReduce(out *[32]byte, s *[64]byte) {
out[30] = byte(s11 >> 9) out[30] = byte(s11 >> 9)
out[31] = byte(s11 >> 17) out[31] = byte(s11 >> 17)
} }
// order is the order of Curve25519 in little-endian form.
var order = [4]uint64{0x5812631a5cf5d3ed, 0x14def9dea2f79cd6, 0, 0x1000000000000000}
// ScMinimal returns true if the given scalar is less than the order of the
// curve.
func ScMinimal(scalar *[32]byte) bool {
for i := 3; ; i-- {
v := binary.LittleEndian.Uint64(scalar[i*8:])
if v > order[i] {
return false
} else if v < order[i] {
break
} else if i == 0 {
return false
}
}
return true
}

View File

@ -2,197 +2,263 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// Package ChaCha20 implements the core ChaCha20 function as specified in https://tools.ietf.org/html/rfc7539#section-2.3. // Package ChaCha20 implements the core ChaCha20 function as specified
// in https://tools.ietf.org/html/rfc7539#section-2.3.
package chacha20 package chacha20
import "encoding/binary" import (
"crypto/cipher"
"encoding/binary"
const rounds = 20 "golang.org/x/crypto/internal/subtle"
)
// core applies the ChaCha20 core function to 16-byte input in, 32-byte key k, // assert that *Cipher implements cipher.Stream
// and 16-byte constant c, and puts the result into 64-byte array out. var _ cipher.Stream = (*Cipher)(nil)
func core(out *[64]byte, in *[16]byte, k *[32]byte) {
j0 := uint32(0x61707865)
j1 := uint32(0x3320646e)
j2 := uint32(0x79622d32)
j3 := uint32(0x6b206574)
j4 := binary.LittleEndian.Uint32(k[0:4])
j5 := binary.LittleEndian.Uint32(k[4:8])
j6 := binary.LittleEndian.Uint32(k[8:12])
j7 := binary.LittleEndian.Uint32(k[12:16])
j8 := binary.LittleEndian.Uint32(k[16:20])
j9 := binary.LittleEndian.Uint32(k[20:24])
j10 := binary.LittleEndian.Uint32(k[24:28])
j11 := binary.LittleEndian.Uint32(k[28:32])
j12 := binary.LittleEndian.Uint32(in[0:4])
j13 := binary.LittleEndian.Uint32(in[4:8])
j14 := binary.LittleEndian.Uint32(in[8:12])
j15 := binary.LittleEndian.Uint32(in[12:16])
x0, x1, x2, x3, x4, x5, x6, x7 := j0, j1, j2, j3, j4, j5, j6, j7 // Cipher is a stateful instance of ChaCha20 using a particular key
x8, x9, x10, x11, x12, x13, x14, x15 := j8, j9, j10, j11, j12, j13, j14, j15 // and nonce. A *Cipher implements the cipher.Stream interface.
type Cipher struct {
key [8]uint32
counter uint32 // incremented after each block
nonce [3]uint32
buf [bufSize]byte // buffer for unused keystream bytes
len int // number of unused keystream bytes at end of buf
}
for i := 0; i < rounds; i += 2 { // New creates a new ChaCha20 stream cipher with the given key and nonce.
x0 += x4 // The initial counter value is set to 0.
x12 ^= x0 func New(key [8]uint32, nonce [3]uint32) *Cipher {
x12 = (x12 << 16) | (x12 >> (16)) return &Cipher{key: key, nonce: nonce}
x8 += x12 }
x4 ^= x8
x4 = (x4 << 12) | (x4 >> (20)) // ChaCha20 constants spelling "expand 32-byte k"
x0 += x4 const (
x12 ^= x0 j0 uint32 = 0x61707865
x12 = (x12 << 8) | (x12 >> (24)) j1 uint32 = 0x3320646e
x8 += x12 j2 uint32 = 0x79622d32
x4 ^= x8 j3 uint32 = 0x6b206574
x4 = (x4 << 7) | (x4 >> (25)) )
x1 += x5
x13 ^= x1 func quarterRound(a, b, c, d uint32) (uint32, uint32, uint32, uint32) {
x13 = (x13 << 16) | (x13 >> 16) a += b
x9 += x13 d ^= a
x5 ^= x9 d = (d << 16) | (d >> 16)
x5 = (x5 << 12) | (x5 >> 20) c += d
x1 += x5 b ^= c
x13 ^= x1 b = (b << 12) | (b >> 20)
x13 = (x13 << 8) | (x13 >> 24) a += b
x9 += x13 d ^= a
x5 ^= x9 d = (d << 8) | (d >> 24)
x5 = (x5 << 7) | (x5 >> 25) c += d
x2 += x6 b ^= c
x14 ^= x2 b = (b << 7) | (b >> 25)
x14 = (x14 << 16) | (x14 >> 16) return a, b, c, d
x10 += x14 }
x6 ^= x10
x6 = (x6 << 12) | (x6 >> 20) // XORKeyStream XORs each byte in the given slice with a byte from the
x2 += x6 // cipher's key stream. Dst and src must overlap entirely or not at all.
x14 ^= x2 //
x14 = (x14 << 8) | (x14 >> 24) // If len(dst) < len(src), XORKeyStream will panic. It is acceptable
x10 += x14 // to pass a dst bigger than src, and in that case, XORKeyStream will
x6 ^= x10 // only update dst[:len(src)] and will not touch the rest of dst.
x6 = (x6 << 7) | (x6 >> 25) //
x3 += x7 // Multiple calls to XORKeyStream behave as if the concatenation of
x15 ^= x3 // the src buffers was passed in a single run. That is, Cipher
x15 = (x15 << 16) | (x15 >> 16) // maintains state and does not reset at each XORKeyStream call.
x11 += x15 func (s *Cipher) XORKeyStream(dst, src []byte) {
x7 ^= x11 if len(dst) < len(src) {
x7 = (x7 << 12) | (x7 >> 20) panic("chacha20: output smaller than input")
x3 += x7 }
x15 ^= x3 if subtle.InexactOverlap(dst[:len(src)], src) {
x15 = (x15 << 8) | (x15 >> 24) panic("chacha20: invalid buffer overlap")
x11 += x15 }
x7 ^= x11
x7 = (x7 << 7) | (x7 >> 25) // xor src with buffered keystream first
x0 += x5 if s.len != 0 {
x15 ^= x0 buf := s.buf[len(s.buf)-s.len:]
x15 = (x15 << 16) | (x15 >> 16) if len(src) < len(buf) {
x10 += x15 buf = buf[:len(src)]
x5 ^= x10 }
x5 = (x5 << 12) | (x5 >> 20) td, ts := dst[:len(buf)], src[:len(buf)] // BCE hint
x0 += x5 for i, b := range buf {
x15 ^= x0 td[i] = ts[i] ^ b
x15 = (x15 << 8) | (x15 >> 24) }
x10 += x15 s.len -= len(buf)
x5 ^= x10 if s.len != 0 {
x5 = (x5 << 7) | (x5 >> 25) return
x1 += x6 }
x12 ^= x1 s.buf = [len(s.buf)]byte{} // zero the empty buffer
x12 = (x12 << 16) | (x12 >> 16) src = src[len(buf):]
x11 += x12 dst = dst[len(buf):]
x6 ^= x11 }
x6 = (x6 << 12) | (x6 >> 20)
x1 += x6 if len(src) == 0 {
x12 ^= x1 return
x12 = (x12 << 8) | (x12 >> 24) }
x11 += x12 if haveAsm {
x6 ^= x11 if uint64(len(src))+uint64(s.counter)*64 > (1<<38)-64 {
x6 = (x6 << 7) | (x6 >> 25) panic("chacha20: counter overflow")
x2 += x7 }
x13 ^= x2 s.xorKeyStreamAsm(dst, src)
x13 = (x13 << 16) | (x13 >> 16) return
x8 += x13 }
x7 ^= x8
x7 = (x7 << 12) | (x7 >> 20) // set up a 64-byte buffer to pad out the final block if needed
x2 += x7 // (hoisted out of the main loop to avoid spills)
x13 ^= x2 rem := len(src) % 64 // length of final block
x13 = (x13 << 8) | (x13 >> 24) fin := len(src) - rem // index of final block
x8 += x13 if rem > 0 {
x7 ^= x8 copy(s.buf[len(s.buf)-64:], src[fin:])
x7 = (x7 << 7) | (x7 >> 25) }
x3 += x4
x14 ^= x3 // pre-calculate most of the first round
x14 = (x14 << 16) | (x14 >> 16) s1, s5, s9, s13 := quarterRound(j1, s.key[1], s.key[5], s.nonce[0])
x9 += x14 s2, s6, s10, s14 := quarterRound(j2, s.key[2], s.key[6], s.nonce[1])
x4 ^= x9 s3, s7, s11, s15 := quarterRound(j3, s.key[3], s.key[7], s.nonce[2])
x4 = (x4 << 12) | (x4 >> 20)
x3 += x4 n := len(src)
x14 ^= x3 src, dst = src[:n:n], dst[:n:n] // BCE hint
x14 = (x14 << 8) | (x14 >> 24) for i := 0; i < n; i += 64 {
x9 += x14 // calculate the remainder of the first round
x4 ^= x9 s0, s4, s8, s12 := quarterRound(j0, s.key[0], s.key[4], s.counter)
x4 = (x4 << 7) | (x4 >> 25)
// execute the second round
x0, x5, x10, x15 := quarterRound(s0, s5, s10, s15)
x1, x6, x11, x12 := quarterRound(s1, s6, s11, s12)
x2, x7, x8, x13 := quarterRound(s2, s7, s8, s13)
x3, x4, x9, x14 := quarterRound(s3, s4, s9, s14)
// execute the remaining 18 rounds
for i := 0; i < 9; i++ {
x0, x4, x8, x12 = quarterRound(x0, x4, x8, x12)
x1, x5, x9, x13 = quarterRound(x1, x5, x9, x13)
x2, x6, x10, x14 = quarterRound(x2, x6, x10, x14)
x3, x7, x11, x15 = quarterRound(x3, x7, x11, x15)
x0, x5, x10, x15 = quarterRound(x0, x5, x10, x15)
x1, x6, x11, x12 = quarterRound(x1, x6, x11, x12)
x2, x7, x8, x13 = quarterRound(x2, x7, x8, x13)
x3, x4, x9, x14 = quarterRound(x3, x4, x9, x14)
} }
x0 += j0 x0 += j0
x1 += j1 x1 += j1
x2 += j2 x2 += j2
x3 += j3 x3 += j3
x4 += j4
x5 += j5
x6 += j6
x7 += j7
x8 += j8
x9 += j9
x10 += j10
x11 += j11
x12 += j12
x13 += j13
x14 += j14
x15 += j15
binary.LittleEndian.PutUint32(out[0:4], x0) x4 += s.key[0]
binary.LittleEndian.PutUint32(out[4:8], x1) x5 += s.key[1]
binary.LittleEndian.PutUint32(out[8:12], x2) x6 += s.key[2]
binary.LittleEndian.PutUint32(out[12:16], x3) x7 += s.key[3]
binary.LittleEndian.PutUint32(out[16:20], x4) x8 += s.key[4]
binary.LittleEndian.PutUint32(out[20:24], x5) x9 += s.key[5]
binary.LittleEndian.PutUint32(out[24:28], x6) x10 += s.key[6]
binary.LittleEndian.PutUint32(out[28:32], x7) x11 += s.key[7]
binary.LittleEndian.PutUint32(out[32:36], x8)
binary.LittleEndian.PutUint32(out[36:40], x9) x12 += s.counter
binary.LittleEndian.PutUint32(out[40:44], x10) x13 += s.nonce[0]
binary.LittleEndian.PutUint32(out[44:48], x11) x14 += s.nonce[1]
binary.LittleEndian.PutUint32(out[48:52], x12) x15 += s.nonce[2]
binary.LittleEndian.PutUint32(out[52:56], x13)
binary.LittleEndian.PutUint32(out[56:60], x14) // increment the counter
binary.LittleEndian.PutUint32(out[60:64], x15) s.counter += 1
if s.counter == 0 {
panic("chacha20: counter overflow")
}
// pad to 64 bytes if needed
in, out := src[i:], dst[i:]
if i == fin {
// src[fin:] has already been copied into s.buf before
// the main loop
in, out = s.buf[len(s.buf)-64:], s.buf[len(s.buf)-64:]
}
in, out = in[:64], out[:64] // BCE hint
// XOR the key stream with the source and write out the result
xor(out[0:], in[0:], x0)
xor(out[4:], in[4:], x1)
xor(out[8:], in[8:], x2)
xor(out[12:], in[12:], x3)
xor(out[16:], in[16:], x4)
xor(out[20:], in[20:], x5)
xor(out[24:], in[24:], x6)
xor(out[28:], in[28:], x7)
xor(out[32:], in[32:], x8)
xor(out[36:], in[36:], x9)
xor(out[40:], in[40:], x10)
xor(out[44:], in[44:], x11)
xor(out[48:], in[48:], x12)
xor(out[52:], in[52:], x13)
xor(out[56:], in[56:], x14)
xor(out[60:], in[60:], x15)
}
// copy any trailing bytes out of the buffer and into dst
if rem != 0 {
s.len = 64 - rem
copy(dst[fin:], s.buf[len(s.buf)-64:])
}
}
// Advance discards bytes in the key stream until the next 64 byte block
// boundary is reached and updates the counter accordingly. If the key
// stream is already at a block boundary no bytes will be discarded and
// the counter will be unchanged.
func (s *Cipher) Advance() {
s.len -= s.len % 64
if s.len == 0 {
s.buf = [len(s.buf)]byte{}
}
} }
// XORKeyStream crypts bytes from in to out using the given key and counters. // XORKeyStream crypts bytes from in to out using the given key and counters.
// In and out must overlap entirely or not at all. Counter contains the raw // In and out must overlap entirely or not at all. Counter contains the raw
// ChaCha20 counter bytes (i.e. block counter followed by nonce). // ChaCha20 counter bytes (i.e. block counter followed by nonce).
func XORKeyStream(out, in []byte, counter *[16]byte, key *[32]byte) { func XORKeyStream(out, in []byte, counter *[16]byte, key *[32]byte) {
var block [64]byte s := Cipher{
var counterCopy [16]byte key: [8]uint32{
copy(counterCopy[:], counter[:]) binary.LittleEndian.Uint32(key[0:4]),
binary.LittleEndian.Uint32(key[4:8]),
for len(in) >= 64 { binary.LittleEndian.Uint32(key[8:12]),
core(&block, &counterCopy, key) binary.LittleEndian.Uint32(key[12:16]),
for i, x := range block { binary.LittleEndian.Uint32(key[16:20]),
out[i] = in[i] ^ x binary.LittleEndian.Uint32(key[20:24]),
binary.LittleEndian.Uint32(key[24:28]),
binary.LittleEndian.Uint32(key[28:32]),
},
nonce: [3]uint32{
binary.LittleEndian.Uint32(counter[4:8]),
binary.LittleEndian.Uint32(counter[8:12]),
binary.LittleEndian.Uint32(counter[12:16]),
},
counter: binary.LittleEndian.Uint32(counter[0:4]),
} }
u := uint32(1) s.XORKeyStream(out, in)
for i := 0; i < 4; i++ {
u += uint32(counterCopy[i])
counterCopy[i] = byte(u)
u >>= 8
}
in = in[64:]
out = out[64:]
} }
if len(in) > 0 { // HChaCha20 uses the ChaCha20 core to generate a derived key from a key and a
core(&block, &counterCopy, key) // nonce. It should only be used as part of the XChaCha20 construction.
for i, v := range in { func HChaCha20(key *[8]uint32, nonce *[4]uint32) [8]uint32 {
out[i] = v ^ block[i] x0, x1, x2, x3 := j0, j1, j2, j3
} x4, x5, x6, x7 := key[0], key[1], key[2], key[3]
x8, x9, x10, x11 := key[4], key[5], key[6], key[7]
x12, x13, x14, x15 := nonce[0], nonce[1], nonce[2], nonce[3]
for i := 0; i < 10; i++ {
x0, x4, x8, x12 = quarterRound(x0, x4, x8, x12)
x1, x5, x9, x13 = quarterRound(x1, x5, x9, x13)
x2, x6, x10, x14 = quarterRound(x2, x6, x10, x14)
x3, x7, x11, x15 = quarterRound(x3, x7, x11, x15)
x0, x5, x10, x15 = quarterRound(x0, x5, x10, x15)
x1, x6, x11, x12 = quarterRound(x1, x6, x11, x12)
x2, x7, x8, x13 = quarterRound(x2, x7, x8, x13)
x3, x4, x9, x14 = quarterRound(x3, x4, x9, x14)
} }
var out [8]uint32
out[0], out[1], out[2], out[3] = x0, x1, x2, x3
out[4], out[5], out[6], out[7] = x12, x13, x14, x15
return out
} }

View File

@ -0,0 +1,16 @@
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build !s390x gccgo appengine
package chacha20
const (
bufSize = 64
haveAsm = false
)
func (*Cipher) xorKeyStreamAsm(dst, src []byte) {
panic("not implemented")
}

View File

@ -0,0 +1,30 @@
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build s390x,!gccgo,!appengine
package chacha20
var haveAsm = hasVectorFacility()
const bufSize = 256
// hasVectorFacility reports whether the machine supports the vector
// facility (vx).
// Implementation in asm_s390x.s.
func hasVectorFacility() bool
// xorKeyStreamVX is an assembly implementation of XORKeyStream. It must only
// be called when the vector facility is available.
// Implementation in asm_s390x.s.
//go:noescape
func xorKeyStreamVX(dst, src []byte, key *[8]uint32, nonce *[3]uint32, counter *uint32, buf *[256]byte, len *int)
func (c *Cipher) xorKeyStreamAsm(dst, src []byte) {
xorKeyStreamVX(dst, src, &c.key, &c.nonce, &c.counter, &c.buf, &c.len)
}
// EXRL targets, DO NOT CALL!
func mvcSrcToBuf()
func mvcBufToDst()

View File

@ -0,0 +1,283 @@
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build s390x,!gccgo,!appengine
#include "go_asm.h"
#include "textflag.h"
// This is an implementation of the ChaCha20 encryption algorithm as
// specified in RFC 7539. It uses vector instructions to compute
// 4 keystream blocks in parallel (256 bytes) which are then XORed
// with the bytes in the input slice.
GLOBL ·constants<>(SB), RODATA|NOPTR, $32
// BSWAP: swap bytes in each 4-byte element
DATA ·constants<>+0x00(SB)/4, $0x03020100
DATA ·constants<>+0x04(SB)/4, $0x07060504
DATA ·constants<>+0x08(SB)/4, $0x0b0a0908
DATA ·constants<>+0x0c(SB)/4, $0x0f0e0d0c
// J0: [j0, j1, j2, j3]
DATA ·constants<>+0x10(SB)/4, $0x61707865
DATA ·constants<>+0x14(SB)/4, $0x3320646e
DATA ·constants<>+0x18(SB)/4, $0x79622d32
DATA ·constants<>+0x1c(SB)/4, $0x6b206574
// EXRL targets:
TEXT ·mvcSrcToBuf(SB), NOFRAME|NOSPLIT, $0
MVC $1, (R1), (R8)
RET
TEXT ·mvcBufToDst(SB), NOFRAME|NOSPLIT, $0
MVC $1, (R8), (R9)
RET
#define BSWAP V5
#define J0 V6
#define KEY0 V7
#define KEY1 V8
#define NONCE V9
#define CTR V10
#define M0 V11
#define M1 V12
#define M2 V13
#define M3 V14
#define INC V15
#define X0 V16
#define X1 V17
#define X2 V18
#define X3 V19
#define X4 V20
#define X5 V21
#define X6 V22
#define X7 V23
#define X8 V24
#define X9 V25
#define X10 V26
#define X11 V27
#define X12 V28
#define X13 V29
#define X14 V30
#define X15 V31
#define NUM_ROUNDS 20
#define ROUND4(a0, a1, a2, a3, b0, b1, b2, b3, c0, c1, c2, c3, d0, d1, d2, d3) \
VAF a1, a0, a0 \
VAF b1, b0, b0 \
VAF c1, c0, c0 \
VAF d1, d0, d0 \
VX a0, a2, a2 \
VX b0, b2, b2 \
VX c0, c2, c2 \
VX d0, d2, d2 \
VERLLF $16, a2, a2 \
VERLLF $16, b2, b2 \
VERLLF $16, c2, c2 \
VERLLF $16, d2, d2 \
VAF a2, a3, a3 \
VAF b2, b3, b3 \
VAF c2, c3, c3 \
VAF d2, d3, d3 \
VX a3, a1, a1 \
VX b3, b1, b1 \
VX c3, c1, c1 \
VX d3, d1, d1 \
VERLLF $12, a1, a1 \
VERLLF $12, b1, b1 \
VERLLF $12, c1, c1 \
VERLLF $12, d1, d1 \
VAF a1, a0, a0 \
VAF b1, b0, b0 \
VAF c1, c0, c0 \
VAF d1, d0, d0 \
VX a0, a2, a2 \
VX b0, b2, b2 \
VX c0, c2, c2 \
VX d0, d2, d2 \
VERLLF $8, a2, a2 \
VERLLF $8, b2, b2 \
VERLLF $8, c2, c2 \
VERLLF $8, d2, d2 \
VAF a2, a3, a3 \
VAF b2, b3, b3 \
VAF c2, c3, c3 \
VAF d2, d3, d3 \
VX a3, a1, a1 \
VX b3, b1, b1 \
VX c3, c1, c1 \
VX d3, d1, d1 \
VERLLF $7, a1, a1 \
VERLLF $7, b1, b1 \
VERLLF $7, c1, c1 \
VERLLF $7, d1, d1
#define PERMUTE(mask, v0, v1, v2, v3) \
VPERM v0, v0, mask, v0 \
VPERM v1, v1, mask, v1 \
VPERM v2, v2, mask, v2 \
VPERM v3, v3, mask, v3
#define ADDV(x, v0, v1, v2, v3) \
VAF x, v0, v0 \
VAF x, v1, v1 \
VAF x, v2, v2 \
VAF x, v3, v3
#define XORV(off, dst, src, v0, v1, v2, v3) \
VLM off(src), M0, M3 \
PERMUTE(BSWAP, v0, v1, v2, v3) \
VX v0, M0, M0 \
VX v1, M1, M1 \
VX v2, M2, M2 \
VX v3, M3, M3 \
VSTM M0, M3, off(dst)
#define SHUFFLE(a, b, c, d, t, u, v, w) \
VMRHF a, c, t \ // t = {a[0], c[0], a[1], c[1]}
VMRHF b, d, u \ // u = {b[0], d[0], b[1], d[1]}
VMRLF a, c, v \ // v = {a[2], c[2], a[3], c[3]}
VMRLF b, d, w \ // w = {b[2], d[2], b[3], d[3]}
VMRHF t, u, a \ // a = {a[0], b[0], c[0], d[0]}
VMRLF t, u, b \ // b = {a[1], b[1], c[1], d[1]}
VMRHF v, w, c \ // c = {a[2], b[2], c[2], d[2]}
VMRLF v, w, d // d = {a[3], b[3], c[3], d[3]}
// func xorKeyStreamVX(dst, src []byte, key *[8]uint32, nonce *[3]uint32, counter *uint32, buf *[256]byte, len *int)
TEXT ·xorKeyStreamVX(SB), NOSPLIT, $0
MOVD $·constants<>(SB), R1
MOVD dst+0(FP), R2 // R2=&dst[0]
LMG src+24(FP), R3, R4 // R3=&src[0] R4=len(src)
MOVD key+48(FP), R5 // R5=key
MOVD nonce+56(FP), R6 // R6=nonce
MOVD counter+64(FP), R7 // R7=counter
MOVD buf+72(FP), R8 // R8=buf
MOVD len+80(FP), R9 // R9=len
// load BSWAP and J0
VLM (R1), BSWAP, J0
// set up tail buffer
ADD $-1, R4, R12
MOVBZ R12, R12
CMPUBEQ R12, $255, aligned
MOVD R4, R1
AND $~255, R1
MOVD $(R3)(R1*1), R1
EXRL $·mvcSrcToBuf(SB), R12
MOVD $255, R0
SUB R12, R0
MOVD R0, (R9) // update len
aligned:
// setup
MOVD $95, R0
VLM (R5), KEY0, KEY1
VLL R0, (R6), NONCE
VZERO M0
VLEIB $7, $32, M0
VSRLB M0, NONCE, NONCE
// initialize counter values
VLREPF (R7), CTR
VZERO INC
VLEIF $1, $1, INC
VLEIF $2, $2, INC
VLEIF $3, $3, INC
VAF INC, CTR, CTR
VREPIF $4, INC
chacha:
VREPF $0, J0, X0
VREPF $1, J0, X1
VREPF $2, J0, X2
VREPF $3, J0, X3
VREPF $0, KEY0, X4
VREPF $1, KEY0, X5
VREPF $2, KEY0, X6
VREPF $3, KEY0, X7
VREPF $0, KEY1, X8
VREPF $1, KEY1, X9
VREPF $2, KEY1, X10
VREPF $3, KEY1, X11
VLR CTR, X12
VREPF $1, NONCE, X13
VREPF $2, NONCE, X14
VREPF $3, NONCE, X15
MOVD $(NUM_ROUNDS/2), R1
loop:
ROUND4(X0, X4, X12, X8, X1, X5, X13, X9, X2, X6, X14, X10, X3, X7, X15, X11)
ROUND4(X0, X5, X15, X10, X1, X6, X12, X11, X2, X7, X13, X8, X3, X4, X14, X9)
ADD $-1, R1
BNE loop
// decrement length
ADD $-256, R4
BLT tail
continue:
// rearrange vectors
SHUFFLE(X0, X1, X2, X3, M0, M1, M2, M3)
ADDV(J0, X0, X1, X2, X3)
SHUFFLE(X4, X5, X6, X7, M0, M1, M2, M3)
ADDV(KEY0, X4, X5, X6, X7)
SHUFFLE(X8, X9, X10, X11, M0, M1, M2, M3)
ADDV(KEY1, X8, X9, X10, X11)
VAF CTR, X12, X12
SHUFFLE(X12, X13, X14, X15, M0, M1, M2, M3)
ADDV(NONCE, X12, X13, X14, X15)
// increment counters
VAF INC, CTR, CTR
// xor keystream with plaintext
XORV(0*64, R2, R3, X0, X4, X8, X12)
XORV(1*64, R2, R3, X1, X5, X9, X13)
XORV(2*64, R2, R3, X2, X6, X10, X14)
XORV(3*64, R2, R3, X3, X7, X11, X15)
// increment pointers
MOVD $256(R2), R2
MOVD $256(R3), R3
CMPBNE R4, $0, chacha
CMPUBEQ R12, $255, return
EXRL $·mvcBufToDst(SB), R12 // len was updated during setup
return:
VSTEF $0, CTR, (R7)
RET
tail:
MOVD R2, R9
MOVD R8, R2
MOVD R8, R3
MOVD $0, R4
JMP continue
// func hasVectorFacility() bool
TEXT ·hasVectorFacility(SB), NOSPLIT, $24-1
MOVD $x-24(SP), R1
XC $24, 0(R1), 0(R1) // clear the storage
MOVD $2, R0 // R0 is the number of double words stored -1
WORD $0xB2B01000 // STFLE 0(R1)
XOR R0, R0 // reset the value of R0
MOVBZ z-8(SP), R1
AND $0x40, R1
BEQ novector
vectorinstalled:
// check if the vector instruction has been enabled
VLEIB $0, $0xF, V16
VLGVB $0, V16, R1
CMPBNE R1, $0xF, novector
MOVB $1, ret+0(FP) // have vx
RET
novector:
MOVB $0, ret+0(FP) // no vx
RET

43
vendor/golang.org/x/crypto/internal/chacha20/xor.go generated vendored Normal file
View File

@ -0,0 +1,43 @@
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found src the LICENSE file.
package chacha20
import (
"runtime"
)
// Platforms that have fast unaligned 32-bit little endian accesses.
const unaligned = runtime.GOARCH == "386" ||
runtime.GOARCH == "amd64" ||
runtime.GOARCH == "arm64" ||
runtime.GOARCH == "ppc64le" ||
runtime.GOARCH == "s390x"
// xor reads a little endian uint32 from src, XORs it with u and
// places the result in little endian byte order in dst.
func xor(dst, src []byte, u uint32) {
_, _ = src[3], dst[3] // eliminate bounds checks
if unaligned {
// The compiler should optimize this code into
// 32-bit unaligned little endian loads and stores.
// TODO: delete once the compiler does a reliably
// good job with the generic code below.
// See issue #25111 for more details.
v := uint32(src[0])
v |= uint32(src[1]) << 8
v |= uint32(src[2]) << 16
v |= uint32(src[3]) << 24
v ^= u
dst[0] = byte(v)
dst[1] = byte(v >> 8)
dst[2] = byte(v >> 16)
dst[3] = byte(v >> 24)
} else {
dst[0] = src[0] ^ byte(u)
dst[1] = src[1] ^ byte(u>>8)
dst[2] = src[2] ^ byte(u>>16)
dst[3] = src[3] ^ byte(u>>24)
}
}

32
vendor/golang.org/x/crypto/internal/subtle/aliasing.go generated vendored Normal file
View File

@ -0,0 +1,32 @@
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build !appengine
// Package subtle implements functions that are often useful in cryptographic
// code but require careful thought to use correctly.
package subtle // import "golang.org/x/crypto/internal/subtle"
import "unsafe"
// AnyOverlap reports whether x and y share memory at any (not necessarily
// corresponding) index. The memory beyond the slice length is ignored.
func AnyOverlap(x, y []byte) bool {
return len(x) > 0 && len(y) > 0 &&
uintptr(unsafe.Pointer(&x[0])) <= uintptr(unsafe.Pointer(&y[len(y)-1])) &&
uintptr(unsafe.Pointer(&y[0])) <= uintptr(unsafe.Pointer(&x[len(x)-1]))
}
// InexactOverlap reports whether x and y share memory at any non-corresponding
// index. The memory beyond the slice length is ignored. Note that x and y can
// have different lengths and still not have any inexact overlap.
//
// InexactOverlap can be used to implement the requirements of the crypto/cipher
// AEAD, Block, BlockMode and Stream interfaces.
func InexactOverlap(x, y []byte) bool {
if len(x) == 0 || len(y) == 0 || &x[0] == &y[0] {
return false
}
return AnyOverlap(x, y)
}

View File

@ -0,0 +1,35 @@
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build appengine
// Package subtle implements functions that are often useful in cryptographic
// code but require careful thought to use correctly.
package subtle // import "golang.org/x/crypto/internal/subtle"
// This is the Google App Engine standard variant based on reflect
// because the unsafe package and cgo are disallowed.
import "reflect"
// AnyOverlap reports whether x and y share memory at any (not necessarily
// corresponding) index. The memory beyond the slice length is ignored.
func AnyOverlap(x, y []byte) bool {
return len(x) > 0 && len(y) > 0 &&
reflect.ValueOf(&x[0]).Pointer() <= reflect.ValueOf(&y[len(y)-1]).Pointer() &&
reflect.ValueOf(&y[0]).Pointer() <= reflect.ValueOf(&x[len(x)-1]).Pointer()
}
// InexactOverlap reports whether x and y share memory at any non-corresponding
// index. The memory beyond the slice length is ignored. Note that x and y can
// have different lengths and still not have any inexact overlap.
//
// InexactOverlap can be used to implement the requirements of the crypto/cipher
// AEAD, Block, BlockMode and Stream interfaces.
func InexactOverlap(x, y []byte) bool {
if len(x) == 0 || len(y) == 0 || &x[0] == &y[0] {
return false
}
return AnyOverlap(x, y)
}

View File

@ -35,6 +35,7 @@ This package is interoperable with NaCl: https://nacl.cr.yp.to/secretbox.html.
package secretbox // import "golang.org/x/crypto/nacl/secretbox" package secretbox // import "golang.org/x/crypto/nacl/secretbox"
import ( import (
"golang.org/x/crypto/internal/subtle"
"golang.org/x/crypto/poly1305" "golang.org/x/crypto/poly1305"
"golang.org/x/crypto/salsa20/salsa" "golang.org/x/crypto/salsa20/salsa"
) )
@ -87,6 +88,9 @@ func Seal(out, message []byte, nonce *[24]byte, key *[32]byte) []byte {
copy(poly1305Key[:], firstBlock[:]) copy(poly1305Key[:], firstBlock[:])
ret, out := sliceForAppend(out, len(message)+poly1305.TagSize) ret, out := sliceForAppend(out, len(message)+poly1305.TagSize)
if subtle.AnyOverlap(out, message) {
panic("nacl: invalid buffer overlap")
}
// We XOR up to 32 bytes of message with the keystream generated from // We XOR up to 32 bytes of message with the keystream generated from
// the first block. // the first block.
@ -118,7 +122,7 @@ func Seal(out, message []byte, nonce *[24]byte, key *[32]byte) []byte {
// Open authenticates and decrypts a box produced by Seal and appends the // Open authenticates and decrypts a box produced by Seal and appends the
// message to out, which must not overlap box. The output will be Overhead // message to out, which must not overlap box. The output will be Overhead
// bytes smaller than box. // bytes smaller than box.
func Open(out []byte, box []byte, nonce *[24]byte, key *[32]byte) ([]byte, bool) { func Open(out, box []byte, nonce *[24]byte, key *[32]byte) ([]byte, bool) {
if len(box) < Overhead { if len(box) < Overhead {
return nil, false return nil, false
} }
@ -143,6 +147,9 @@ func Open(out []byte, box []byte, nonce *[24]byte, key *[32]byte) ([]byte, bool)
} }
ret, out := sliceForAppend(out, len(box)-Overhead) ret, out := sliceForAppend(out, len(box)-Overhead)
if subtle.AnyOverlap(out, box) {
panic("nacl: invalid buffer overlap")
}
// We XOR up to 32 bytes of box with the keystream generated from // We XOR up to 32 bytes of box with the keystream generated from
// the first block. // the first block.

14
vendor/golang.org/x/crypto/poly1305/sum_noasm.go generated vendored Normal file
View File

@ -0,0 +1,14 @@
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build s390x,!go1.11 !arm,!amd64,!s390x gccgo appengine nacl
package poly1305
// Sum generates an authenticator for msg using a one-time key and puts the
// 16-byte result into out. Authenticating two different messages with the same
// key allows an attacker to forge messages at will.
func Sum(out *[TagSize]byte, msg []byte, key *[32]byte) {
sumGeneric(out, msg, key)
}

View File

@ -2,16 +2,14 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// +build !amd64,!arm gccgo appengine nacl
package poly1305 package poly1305
import "encoding/binary" import "encoding/binary"
// Sum generates an authenticator for msg using a one-time key and puts the // sumGeneric generates an authenticator for msg using a one-time key and
// 16-byte result into out. Authenticating two different messages with the same // puts the 16-byte result into out. This is the generic implementation of
// key allows an attacker to forge messages at will. // Sum and should be called if no assembly implementation is available.
func Sum(out *[TagSize]byte, msg []byte, key *[32]byte) { func sumGeneric(out *[TagSize]byte, msg []byte, key *[32]byte) {
var ( var (
h0, h1, h2, h3, h4 uint32 // the hash accumulators h0, h1, h2, h3, h4 uint32 // the hash accumulators
r0, r1, r2, r3, r4 uint64 // the r part of the key r0, r1, r2, r3, r4 uint64 // the r part of the key

49
vendor/golang.org/x/crypto/poly1305/sum_s390x.go generated vendored Normal file
View File

@ -0,0 +1,49 @@
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build s390x,go1.11,!gccgo,!appengine
package poly1305
// hasVectorFacility reports whether the machine supports
// the vector facility (vx).
func hasVectorFacility() bool
// hasVMSLFacility reports whether the machine supports
// Vector Multiply Sum Logical (VMSL).
func hasVMSLFacility() bool
var hasVX = hasVectorFacility()
var hasVMSL = hasVMSLFacility()
// poly1305vx is an assembly implementation of Poly1305 that uses vector
// instructions. It must only be called if the vector facility (vx) is
// available.
//go:noescape
func poly1305vx(out *[16]byte, m *byte, mlen uint64, key *[32]byte)
// poly1305vmsl is an assembly implementation of Poly1305 that uses vector
// instructions, including VMSL. It must only be called if the vector facility (vx) is
// available and if VMSL is supported.
//go:noescape
func poly1305vmsl(out *[16]byte, m *byte, mlen uint64, key *[32]byte)
// Sum generates an authenticator for m using a one-time key and puts the
// 16-byte result into out. Authenticating two different messages with the same
// key allows an attacker to forge messages at will.
func Sum(out *[16]byte, m []byte, key *[32]byte) {
if hasVX {
var mPtr *byte
if len(m) > 0 {
mPtr = &m[0]
}
if hasVMSL && len(m) > 256 {
poly1305vmsl(out, mPtr, uint64(len(m)), key)
} else {
poly1305vx(out, mPtr, uint64(len(m)), key)
}
} else {
sumGeneric(out, m, key)
}
}

400
vendor/golang.org/x/crypto/poly1305/sum_s390x.s generated vendored Normal file
View File

@ -0,0 +1,400 @@
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build s390x,go1.11,!gccgo,!appengine
#include "textflag.h"
// Implementation of Poly1305 using the vector facility (vx).
// constants
#define MOD26 V0
#define EX0 V1
#define EX1 V2
#define EX2 V3
// temporaries
#define T_0 V4
#define T_1 V5
#define T_2 V6
#define T_3 V7
#define T_4 V8
// key (r)
#define R_0 V9
#define R_1 V10
#define R_2 V11
#define R_3 V12
#define R_4 V13
#define R5_1 V14
#define R5_2 V15
#define R5_3 V16
#define R5_4 V17
#define RSAVE_0 R5
#define RSAVE_1 R6
#define RSAVE_2 R7
#define RSAVE_3 R8
#define RSAVE_4 R9
#define R5SAVE_1 V28
#define R5SAVE_2 V29
#define R5SAVE_3 V30
#define R5SAVE_4 V31
// message block
#define F_0 V18
#define F_1 V19
#define F_2 V20
#define F_3 V21
#define F_4 V22
// accumulator
#define H_0 V23
#define H_1 V24
#define H_2 V25
#define H_3 V26
#define H_4 V27
GLOBL ·keyMask<>(SB), RODATA, $16
DATA ·keyMask<>+0(SB)/8, $0xffffff0ffcffff0f
DATA ·keyMask<>+8(SB)/8, $0xfcffff0ffcffff0f
GLOBL ·bswapMask<>(SB), RODATA, $16
DATA ·bswapMask<>+0(SB)/8, $0x0f0e0d0c0b0a0908
DATA ·bswapMask<>+8(SB)/8, $0x0706050403020100
GLOBL ·constants<>(SB), RODATA, $64
// MOD26
DATA ·constants<>+0(SB)/8, $0x3ffffff
DATA ·constants<>+8(SB)/8, $0x3ffffff
// EX0
DATA ·constants<>+16(SB)/8, $0x0006050403020100
DATA ·constants<>+24(SB)/8, $0x1016151413121110
// EX1
DATA ·constants<>+32(SB)/8, $0x060c0b0a09080706
DATA ·constants<>+40(SB)/8, $0x161c1b1a19181716
// EX2
DATA ·constants<>+48(SB)/8, $0x0d0d0d0d0d0f0e0d
DATA ·constants<>+56(SB)/8, $0x1d1d1d1d1d1f1e1d
// h = (f*g) % (2**130-5) [partial reduction]
#define MULTIPLY(f0, f1, f2, f3, f4, g0, g1, g2, g3, g4, g51, g52, g53, g54, h0, h1, h2, h3, h4) \
VMLOF f0, g0, h0 \
VMLOF f0, g1, h1 \
VMLOF f0, g2, h2 \
VMLOF f0, g3, h3 \
VMLOF f0, g4, h4 \
VMLOF f1, g54, T_0 \
VMLOF f1, g0, T_1 \
VMLOF f1, g1, T_2 \
VMLOF f1, g2, T_3 \
VMLOF f1, g3, T_4 \
VMALOF f2, g53, h0, h0 \
VMALOF f2, g54, h1, h1 \
VMALOF f2, g0, h2, h2 \
VMALOF f2, g1, h3, h3 \
VMALOF f2, g2, h4, h4 \
VMALOF f3, g52, T_0, T_0 \
VMALOF f3, g53, T_1, T_1 \
VMALOF f3, g54, T_2, T_2 \
VMALOF f3, g0, T_3, T_3 \
VMALOF f3, g1, T_4, T_4 \
VMALOF f4, g51, h0, h0 \
VMALOF f4, g52, h1, h1 \
VMALOF f4, g53, h2, h2 \
VMALOF f4, g54, h3, h3 \
VMALOF f4, g0, h4, h4 \
VAG T_0, h0, h0 \
VAG T_1, h1, h1 \
VAG T_2, h2, h2 \
VAG T_3, h3, h3 \
VAG T_4, h4, h4
// carry h0->h1 h3->h4, h1->h2 h4->h0, h0->h1 h2->h3, h3->h4
#define REDUCE(h0, h1, h2, h3, h4) \
VESRLG $26, h0, T_0 \
VESRLG $26, h3, T_1 \
VN MOD26, h0, h0 \
VN MOD26, h3, h3 \
VAG T_0, h1, h1 \
VAG T_1, h4, h4 \
VESRLG $26, h1, T_2 \
VESRLG $26, h4, T_3 \
VN MOD26, h1, h1 \
VN MOD26, h4, h4 \
VESLG $2, T_3, T_4 \
VAG T_3, T_4, T_4 \
VAG T_2, h2, h2 \
VAG T_4, h0, h0 \
VESRLG $26, h2, T_0 \
VESRLG $26, h0, T_1 \
VN MOD26, h2, h2 \
VN MOD26, h0, h0 \
VAG T_0, h3, h3 \
VAG T_1, h1, h1 \
VESRLG $26, h3, T_2 \
VN MOD26, h3, h3 \
VAG T_2, h4, h4
// expand in0 into d[0] and in1 into d[1]
#define EXPAND(in0, in1, d0, d1, d2, d3, d4) \
VGBM $0x0707, d1 \ // d1=tmp
VPERM in0, in1, EX2, d4 \
VPERM in0, in1, EX0, d0 \
VPERM in0, in1, EX1, d2 \
VN d1, d4, d4 \
VESRLG $26, d0, d1 \
VESRLG $30, d2, d3 \
VESRLG $4, d2, d2 \
VN MOD26, d0, d0 \
VN MOD26, d1, d1 \
VN MOD26, d2, d2 \
VN MOD26, d3, d3
// pack h4:h0 into h1:h0 (no carry)
#define PACK(h0, h1, h2, h3, h4) \
VESLG $26, h1, h1 \
VESLG $26, h3, h3 \
VO h0, h1, h0 \
VO h2, h3, h2 \
VESLG $4, h2, h2 \
VLEIB $7, $48, h1 \
VSLB h1, h2, h2 \
VO h0, h2, h0 \
VLEIB $7, $104, h1 \
VSLB h1, h4, h3 \
VO h3, h0, h0 \
VLEIB $7, $24, h1 \
VSRLB h1, h4, h1
// if h > 2**130-5 then h -= 2**130-5
#define MOD(h0, h1, t0, t1, t2) \
VZERO t0 \
VLEIG $1, $5, t0 \
VACCQ h0, t0, t1 \
VAQ h0, t0, t0 \
VONE t2 \
VLEIG $1, $-4, t2 \
VAQ t2, t1, t1 \
VACCQ h1, t1, t1 \
VONE t2 \
VAQ t2, t1, t1 \
VN h0, t1, t2 \
VNC t0, t1, t1 \
VO t1, t2, h0
// func poly1305vx(out *[16]byte, m *byte, mlen uint64, key *[32]key)
TEXT ·poly1305vx(SB), $0-32
// This code processes up to 2 blocks (32 bytes) per iteration
// using the algorithm described in:
// NEON crypto, Daniel J. Bernstein & Peter Schwabe
// https://cryptojedi.org/papers/neoncrypto-20120320.pdf
LMG out+0(FP), R1, R4 // R1=out, R2=m, R3=mlen, R4=key
// load MOD26, EX0, EX1 and EX2
MOVD $·constants<>(SB), R5
VLM (R5), MOD26, EX2
// setup r
VL (R4), T_0
MOVD $·keyMask<>(SB), R6
VL (R6), T_1
VN T_0, T_1, T_0
EXPAND(T_0, T_0, R_0, R_1, R_2, R_3, R_4)
// setup r*5
VLEIG $0, $5, T_0
VLEIG $1, $5, T_0
// store r (for final block)
VMLOF T_0, R_1, R5SAVE_1
VMLOF T_0, R_2, R5SAVE_2
VMLOF T_0, R_3, R5SAVE_3
VMLOF T_0, R_4, R5SAVE_4
VLGVG $0, R_0, RSAVE_0
VLGVG $0, R_1, RSAVE_1
VLGVG $0, R_2, RSAVE_2
VLGVG $0, R_3, RSAVE_3
VLGVG $0, R_4, RSAVE_4
// skip r**2 calculation
CMPBLE R3, $16, skip
// calculate r**2
MULTIPLY(R_0, R_1, R_2, R_3, R_4, R_0, R_1, R_2, R_3, R_4, R5SAVE_1, R5SAVE_2, R5SAVE_3, R5SAVE_4, H_0, H_1, H_2, H_3, H_4)
REDUCE(H_0, H_1, H_2, H_3, H_4)
VLEIG $0, $5, T_0
VLEIG $1, $5, T_0
VMLOF T_0, H_1, R5_1
VMLOF T_0, H_2, R5_2
VMLOF T_0, H_3, R5_3
VMLOF T_0, H_4, R5_4
VLR H_0, R_0
VLR H_1, R_1
VLR H_2, R_2
VLR H_3, R_3
VLR H_4, R_4
// initialize h
VZERO H_0
VZERO H_1
VZERO H_2
VZERO H_3
VZERO H_4
loop:
CMPBLE R3, $32, b2
VLM (R2), T_0, T_1
SUB $32, R3
MOVD $32(R2), R2
EXPAND(T_0, T_1, F_0, F_1, F_2, F_3, F_4)
VLEIB $4, $1, F_4
VLEIB $12, $1, F_4
multiply:
VAG H_0, F_0, F_0
VAG H_1, F_1, F_1
VAG H_2, F_2, F_2
VAG H_3, F_3, F_3
VAG H_4, F_4, F_4
MULTIPLY(F_0, F_1, F_2, F_3, F_4, R_0, R_1, R_2, R_3, R_4, R5_1, R5_2, R5_3, R5_4, H_0, H_1, H_2, H_3, H_4)
REDUCE(H_0, H_1, H_2, H_3, H_4)
CMPBNE R3, $0, loop
finish:
// sum vectors
VZERO T_0
VSUMQG H_0, T_0, H_0
VSUMQG H_1, T_0, H_1
VSUMQG H_2, T_0, H_2
VSUMQG H_3, T_0, H_3
VSUMQG H_4, T_0, H_4
// h may be >= 2*(2**130-5) so we need to reduce it again
REDUCE(H_0, H_1, H_2, H_3, H_4)
// carry h1->h4
VESRLG $26, H_1, T_1
VN MOD26, H_1, H_1
VAQ T_1, H_2, H_2
VESRLG $26, H_2, T_2
VN MOD26, H_2, H_2
VAQ T_2, H_3, H_3
VESRLG $26, H_3, T_3
VN MOD26, H_3, H_3
VAQ T_3, H_4, H_4
// h is now < 2*(2**130-5)
// pack h into h1 (hi) and h0 (lo)
PACK(H_0, H_1, H_2, H_3, H_4)
// if h > 2**130-5 then h -= 2**130-5
MOD(H_0, H_1, T_0, T_1, T_2)
// h += s
MOVD $·bswapMask<>(SB), R5
VL (R5), T_1
VL 16(R4), T_0
VPERM T_0, T_0, T_1, T_0 // reverse bytes (to big)
VAQ T_0, H_0, H_0
VPERM H_0, H_0, T_1, H_0 // reverse bytes (to little)
VST H_0, (R1)
RET
b2:
CMPBLE R3, $16, b1
// 2 blocks remaining
SUB $17, R3
VL (R2), T_0
VLL R3, 16(R2), T_1
ADD $1, R3
MOVBZ $1, R0
CMPBEQ R3, $16, 2(PC)
VLVGB R3, R0, T_1
EXPAND(T_0, T_1, F_0, F_1, F_2, F_3, F_4)
CMPBNE R3, $16, 2(PC)
VLEIB $12, $1, F_4
VLEIB $4, $1, F_4
// setup [r²,r]
VLVGG $1, RSAVE_0, R_0
VLVGG $1, RSAVE_1, R_1
VLVGG $1, RSAVE_2, R_2
VLVGG $1, RSAVE_3, R_3
VLVGG $1, RSAVE_4, R_4
VPDI $0, R5_1, R5SAVE_1, R5_1
VPDI $0, R5_2, R5SAVE_2, R5_2
VPDI $0, R5_3, R5SAVE_3, R5_3
VPDI $0, R5_4, R5SAVE_4, R5_4
MOVD $0, R3
BR multiply
skip:
VZERO H_0
VZERO H_1
VZERO H_2
VZERO H_3
VZERO H_4
CMPBEQ R3, $0, finish
b1:
// 1 block remaining
SUB $1, R3
VLL R3, (R2), T_0
ADD $1, R3
MOVBZ $1, R0
CMPBEQ R3, $16, 2(PC)
VLVGB R3, R0, T_0
VZERO T_1
EXPAND(T_0, T_1, F_0, F_1, F_2, F_3, F_4)
CMPBNE R3, $16, 2(PC)
VLEIB $4, $1, F_4
VLEIG $1, $1, R_0
VZERO R_1
VZERO R_2
VZERO R_3
VZERO R_4
VZERO R5_1
VZERO R5_2
VZERO R5_3
VZERO R5_4
// setup [r, 1]
VLVGG $0, RSAVE_0, R_0
VLVGG $0, RSAVE_1, R_1
VLVGG $0, RSAVE_2, R_2
VLVGG $0, RSAVE_3, R_3
VLVGG $0, RSAVE_4, R_4
VPDI $0, R5SAVE_1, R5_1, R5_1
VPDI $0, R5SAVE_2, R5_2, R5_2
VPDI $0, R5SAVE_3, R5_3, R5_3
VPDI $0, R5SAVE_4, R5_4, R5_4
MOVD $0, R3
BR multiply
TEXT ·hasVectorFacility(SB), NOSPLIT, $24-1
MOVD $x-24(SP), R1
XC $24, 0(R1), 0(R1) // clear the storage
MOVD $2, R0 // R0 is the number of double words stored -1
WORD $0xB2B01000 // STFLE 0(R1)
XOR R0, R0 // reset the value of R0
MOVBZ z-8(SP), R1
AND $0x40, R1
BEQ novector
vectorinstalled:
// check if the vector instruction has been enabled
VLEIB $0, $0xF, V16
VLGVB $0, V16, R1
CMPBNE R1, $0xF, novector
MOVB $1, ret+0(FP) // have vx
RET
novector:
MOVB $0, ret+0(FP) // no vx
RET

931
vendor/golang.org/x/crypto/poly1305/sum_vmsl_s390x.s generated vendored Normal file
View File

@ -0,0 +1,931 @@
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build s390x,go1.11,!gccgo,!appengine
#include "textflag.h"
// Implementation of Poly1305 using the vector facility (vx) and the VMSL instruction.
// constants
#define EX0 V1
#define EX1 V2
#define EX2 V3
// temporaries
#define T_0 V4
#define T_1 V5
#define T_2 V6
#define T_3 V7
#define T_4 V8
#define T_5 V9
#define T_6 V10
#define T_7 V11
#define T_8 V12
#define T_9 V13
#define T_10 V14
// r**2 & r**4
#define R_0 V15
#define R_1 V16
#define R_2 V17
#define R5_1 V18
#define R5_2 V19
// key (r)
#define RSAVE_0 R7
#define RSAVE_1 R8
#define RSAVE_2 R9
#define R5SAVE_1 R10
#define R5SAVE_2 R11
// message block
#define M0 V20
#define M1 V21
#define M2 V22
#define M3 V23
#define M4 V24
#define M5 V25
// accumulator
#define H0_0 V26
#define H1_0 V27
#define H2_0 V28
#define H0_1 V29
#define H1_1 V30
#define H2_1 V31
GLOBL ·keyMask<>(SB), RODATA, $16
DATA ·keyMask<>+0(SB)/8, $0xffffff0ffcffff0f
DATA ·keyMask<>+8(SB)/8, $0xfcffff0ffcffff0f
GLOBL ·bswapMask<>(SB), RODATA, $16
DATA ·bswapMask<>+0(SB)/8, $0x0f0e0d0c0b0a0908
DATA ·bswapMask<>+8(SB)/8, $0x0706050403020100
GLOBL ·constants<>(SB), RODATA, $48
// EX0
DATA ·constants<>+0(SB)/8, $0x18191a1b1c1d1e1f
DATA ·constants<>+8(SB)/8, $0x0000050403020100
// EX1
DATA ·constants<>+16(SB)/8, $0x18191a1b1c1d1e1f
DATA ·constants<>+24(SB)/8, $0x00000a0908070605
// EX2
DATA ·constants<>+32(SB)/8, $0x18191a1b1c1d1e1f
DATA ·constants<>+40(SB)/8, $0x0000000f0e0d0c0b
GLOBL ·c<>(SB), RODATA, $48
// EX0
DATA ·c<>+0(SB)/8, $0x0000050403020100
DATA ·c<>+8(SB)/8, $0x0000151413121110
// EX1
DATA ·c<>+16(SB)/8, $0x00000a0908070605
DATA ·c<>+24(SB)/8, $0x00001a1918171615
// EX2
DATA ·c<>+32(SB)/8, $0x0000000f0e0d0c0b
DATA ·c<>+40(SB)/8, $0x0000001f1e1d1c1b
GLOBL ·reduce<>(SB), RODATA, $32
// 44 bit
DATA ·reduce<>+0(SB)/8, $0x0
DATA ·reduce<>+8(SB)/8, $0xfffffffffff
// 42 bit
DATA ·reduce<>+16(SB)/8, $0x0
DATA ·reduce<>+24(SB)/8, $0x3ffffffffff
// h = (f*g) % (2**130-5) [partial reduction]
// uses T_0...T_9 temporary registers
// input: m02_0, m02_1, m02_2, m13_0, m13_1, m13_2, r_0, r_1, r_2, r5_1, r5_2, m4_0, m4_1, m4_2, m5_0, m5_1, m5_2
// temp: t0, t1, t2, t3, t4, t5, t6, t7, t8, t9
// output: m02_0, m02_1, m02_2, m13_0, m13_1, m13_2
#define MULTIPLY(m02_0, m02_1, m02_2, m13_0, m13_1, m13_2, r_0, r_1, r_2, r5_1, r5_2, m4_0, m4_1, m4_2, m5_0, m5_1, m5_2, t0, t1, t2, t3, t4, t5, t6, t7, t8, t9) \
\ // Eliminate the dependency for the last 2 VMSLs
VMSLG m02_0, r_2, m4_2, m4_2 \
VMSLG m13_0, r_2, m5_2, m5_2 \ // 8 VMSLs pipelined
VMSLG m02_0, r_0, m4_0, m4_0 \
VMSLG m02_1, r5_2, V0, T_0 \
VMSLG m02_0, r_1, m4_1, m4_1 \
VMSLG m02_1, r_0, V0, T_1 \
VMSLG m02_1, r_1, V0, T_2 \
VMSLG m02_2, r5_1, V0, T_3 \
VMSLG m02_2, r5_2, V0, T_4 \
VMSLG m13_0, r_0, m5_0, m5_0 \
VMSLG m13_1, r5_2, V0, T_5 \
VMSLG m13_0, r_1, m5_1, m5_1 \
VMSLG m13_1, r_0, V0, T_6 \
VMSLG m13_1, r_1, V0, T_7 \
VMSLG m13_2, r5_1, V0, T_8 \
VMSLG m13_2, r5_2, V0, T_9 \
VMSLG m02_2, r_0, m4_2, m4_2 \
VMSLG m13_2, r_0, m5_2, m5_2 \
VAQ m4_0, T_0, m02_0 \
VAQ m4_1, T_1, m02_1 \
VAQ m5_0, T_5, m13_0 \
VAQ m5_1, T_6, m13_1 \
VAQ m02_0, T_3, m02_0 \
VAQ m02_1, T_4, m02_1 \
VAQ m13_0, T_8, m13_0 \
VAQ m13_1, T_9, m13_1 \
VAQ m4_2, T_2, m02_2 \
VAQ m5_2, T_7, m13_2 \
// SQUARE uses three limbs of r and r_2*5 to output square of r
// uses T_1, T_5 and T_7 temporary registers
// input: r_0, r_1, r_2, r5_2
// temp: TEMP0, TEMP1, TEMP2
// output: p0, p1, p2
#define SQUARE(r_0, r_1, r_2, r5_2, p0, p1, p2, TEMP0, TEMP1, TEMP2) \
VMSLG r_0, r_0, p0, p0 \
VMSLG r_1, r5_2, V0, TEMP0 \
VMSLG r_2, r5_2, p1, p1 \
VMSLG r_0, r_1, V0, TEMP1 \
VMSLG r_1, r_1, p2, p2 \
VMSLG r_0, r_2, V0, TEMP2 \
VAQ TEMP0, p0, p0 \
VAQ TEMP1, p1, p1 \
VAQ TEMP2, p2, p2 \
VAQ TEMP0, p0, p0 \
VAQ TEMP1, p1, p1 \
VAQ TEMP2, p2, p2 \
// carry h0->h1->h2->h0 || h3->h4->h5->h3
// uses T_2, T_4, T_5, T_7, T_8, T_9
// t6, t7, t8, t9, t10, t11
// input: h0, h1, h2, h3, h4, h5
// temp: t0, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, t11
// output: h0, h1, h2, h3, h4, h5
#define REDUCE(h0, h1, h2, h3, h4, h5, t0, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, t11) \
VLM (R12), t6, t7 \ // 44 and 42 bit clear mask
VLEIB $7, $0x28, t10 \ // 5 byte shift mask
VREPIB $4, t8 \ // 4 bit shift mask
VREPIB $2, t11 \ // 2 bit shift mask
VSRLB t10, h0, t0 \ // h0 byte shift
VSRLB t10, h1, t1 \ // h1 byte shift
VSRLB t10, h2, t2 \ // h2 byte shift
VSRLB t10, h3, t3 \ // h3 byte shift
VSRLB t10, h4, t4 \ // h4 byte shift
VSRLB t10, h5, t5 \ // h5 byte shift
VSRL t8, t0, t0 \ // h0 bit shift
VSRL t8, t1, t1 \ // h2 bit shift
VSRL t11, t2, t2 \ // h2 bit shift
VSRL t8, t3, t3 \ // h3 bit shift
VSRL t8, t4, t4 \ // h4 bit shift
VESLG $2, t2, t9 \ // h2 carry x5
VSRL t11, t5, t5 \ // h5 bit shift
VN t6, h0, h0 \ // h0 clear carry
VAQ t2, t9, t2 \ // h2 carry x5
VESLG $2, t5, t9 \ // h5 carry x5
VN t6, h1, h1 \ // h1 clear carry
VN t7, h2, h2 \ // h2 clear carry
VAQ t5, t9, t5 \ // h5 carry x5
VN t6, h3, h3 \ // h3 clear carry
VN t6, h4, h4 \ // h4 clear carry
VN t7, h5, h5 \ // h5 clear carry
VAQ t0, h1, h1 \ // h0->h1
VAQ t3, h4, h4 \ // h3->h4
VAQ t1, h2, h2 \ // h1->h2
VAQ t4, h5, h5 \ // h4->h5
VAQ t2, h0, h0 \ // h2->h0
VAQ t5, h3, h3 \ // h5->h3
VREPG $1, t6, t6 \ // 44 and 42 bit masks across both halves
VREPG $1, t7, t7 \
VSLDB $8, h0, h0, h0 \ // set up [h0/1/2, h3/4/5]
VSLDB $8, h1, h1, h1 \
VSLDB $8, h2, h2, h2 \
VO h0, h3, h3 \
VO h1, h4, h4 \
VO h2, h5, h5 \
VESRLG $44, h3, t0 \ // 44 bit shift right
VESRLG $44, h4, t1 \
VESRLG $42, h5, t2 \
VN t6, h3, h3 \ // clear carry bits
VN t6, h4, h4 \
VN t7, h5, h5 \
VESLG $2, t2, t9 \ // multiply carry by 5
VAQ t9, t2, t2 \
VAQ t0, h4, h4 \
VAQ t1, h5, h5 \
VAQ t2, h3, h3 \
// carry h0->h1->h2->h0
// input: h0, h1, h2
// temp: t0, t1, t2, t3, t4, t5, t6, t7, t8
// output: h0, h1, h2
#define REDUCE2(h0, h1, h2, t0, t1, t2, t3, t4, t5, t6, t7, t8) \
VLEIB $7, $0x28, t3 \ // 5 byte shift mask
VREPIB $4, t4 \ // 4 bit shift mask
VREPIB $2, t7 \ // 2 bit shift mask
VGBM $0x003F, t5 \ // mask to clear carry bits
VSRLB t3, h0, t0 \
VSRLB t3, h1, t1 \
VSRLB t3, h2, t2 \
VESRLG $4, t5, t5 \ // 44 bit clear mask
VSRL t4, t0, t0 \
VSRL t4, t1, t1 \
VSRL t7, t2, t2 \
VESRLG $2, t5, t6 \ // 42 bit clear mask
VESLG $2, t2, t8 \
VAQ t8, t2, t2 \
VN t5, h0, h0 \
VN t5, h1, h1 \
VN t6, h2, h2 \
VAQ t0, h1, h1 \
VAQ t1, h2, h2 \
VAQ t2, h0, h0 \
VSRLB t3, h0, t0 \
VSRLB t3, h1, t1 \
VSRLB t3, h2, t2 \
VSRL t4, t0, t0 \
VSRL t4, t1, t1 \
VSRL t7, t2, t2 \
VN t5, h0, h0 \
VN t5, h1, h1 \
VESLG $2, t2, t8 \
VN t6, h2, h2 \
VAQ t0, h1, h1 \
VAQ t8, t2, t2 \
VAQ t1, h2, h2 \
VAQ t2, h0, h0 \
// expands two message blocks into the lower halfs of the d registers
// moves the contents of the d registers into upper halfs
// input: in1, in2, d0, d1, d2, d3, d4, d5
// temp: TEMP0, TEMP1, TEMP2, TEMP3
// output: d0, d1, d2, d3, d4, d5
#define EXPACC(in1, in2, d0, d1, d2, d3, d4, d5, TEMP0, TEMP1, TEMP2, TEMP3) \
VGBM $0xff3f, TEMP0 \
VGBM $0xff1f, TEMP1 \
VESLG $4, d1, TEMP2 \
VESLG $4, d4, TEMP3 \
VESRLG $4, TEMP0, TEMP0 \
VPERM in1, d0, EX0, d0 \
VPERM in2, d3, EX0, d3 \
VPERM in1, d2, EX2, d2 \
VPERM in2, d5, EX2, d5 \
VPERM in1, TEMP2, EX1, d1 \
VPERM in2, TEMP3, EX1, d4 \
VN TEMP0, d0, d0 \
VN TEMP0, d3, d3 \
VESRLG $4, d1, d1 \
VESRLG $4, d4, d4 \
VN TEMP1, d2, d2 \
VN TEMP1, d5, d5 \
VN TEMP0, d1, d1 \
VN TEMP0, d4, d4 \
// expands one message block into the lower halfs of the d registers
// moves the contents of the d registers into upper halfs
// input: in, d0, d1, d2
// temp: TEMP0, TEMP1, TEMP2
// output: d0, d1, d2
#define EXPACC2(in, d0, d1, d2, TEMP0, TEMP1, TEMP2) \
VGBM $0xff3f, TEMP0 \
VESLG $4, d1, TEMP2 \
VGBM $0xff1f, TEMP1 \
VPERM in, d0, EX0, d0 \
VESRLG $4, TEMP0, TEMP0 \
VPERM in, d2, EX2, d2 \
VPERM in, TEMP2, EX1, d1 \
VN TEMP0, d0, d0 \
VN TEMP1, d2, d2 \
VESRLG $4, d1, d1 \
VN TEMP0, d1, d1 \
// pack h2:h0 into h1:h0 (no carry)
// input: h0, h1, h2
// output: h0, h1, h2
#define PACK(h0, h1, h2) \
VMRLG h1, h2, h2 \ // copy h1 to upper half h2
VESLG $44, h1, h1 \ // shift limb 1 44 bits, leaving 20
VO h0, h1, h0 \ // combine h0 with 20 bits from limb 1
VESRLG $20, h2, h1 \ // put top 24 bits of limb 1 into h1
VLEIG $1, $0, h1 \ // clear h2 stuff from lower half of h1
VO h0, h1, h0 \ // h0 now has 88 bits (limb 0 and 1)
VLEIG $0, $0, h2 \ // clear upper half of h2
VESRLG $40, h2, h1 \ // h1 now has upper two bits of result
VLEIB $7, $88, h1 \ // for byte shift (11 bytes)
VSLB h1, h2, h2 \ // shift h2 11 bytes to the left
VO h0, h2, h0 \ // combine h0 with 20 bits from limb 1
VLEIG $0, $0, h1 \ // clear upper half of h1
// if h > 2**130-5 then h -= 2**130-5
// input: h0, h1
// temp: t0, t1, t2
// output: h0
#define MOD(h0, h1, t0, t1, t2) \
VZERO t0 \
VLEIG $1, $5, t0 \
VACCQ h0, t0, t1 \
VAQ h0, t0, t0 \
VONE t2 \
VLEIG $1, $-4, t2 \
VAQ t2, t1, t1 \
VACCQ h1, t1, t1 \
VONE t2 \
VAQ t2, t1, t1 \
VN h0, t1, t2 \
VNC t0, t1, t1 \
VO t1, t2, h0 \
// func poly1305vmsl(out *[16]byte, m *byte, mlen uint64, key *[32]key)
TEXT ·poly1305vmsl(SB), $0-32
// This code processes 6 + up to 4 blocks (32 bytes) per iteration
// using the algorithm described in:
// NEON crypto, Daniel J. Bernstein & Peter Schwabe
// https://cryptojedi.org/papers/neoncrypto-20120320.pdf
// And as moddified for VMSL as described in
// Accelerating Poly1305 Cryptographic Message Authentication on the z14
// O'Farrell et al, CASCON 2017, p48-55
// https://ibm.ent.box.com/s/jf9gedj0e9d2vjctfyh186shaztavnht
LMG out+0(FP), R1, R4 // R1=out, R2=m, R3=mlen, R4=key
VZERO V0 // c
// load EX0, EX1 and EX2
MOVD $·constants<>(SB), R5
VLM (R5), EX0, EX2 // c
// setup r
VL (R4), T_0
MOVD $·keyMask<>(SB), R6
VL (R6), T_1
VN T_0, T_1, T_0
VZERO T_2 // limbs for r
VZERO T_3
VZERO T_4
EXPACC2(T_0, T_2, T_3, T_4, T_1, T_5, T_7)
// T_2, T_3, T_4: [0, r]
// setup r*20
VLEIG $0, $0, T_0
VLEIG $1, $20, T_0 // T_0: [0, 20]
VZERO T_5
VZERO T_6
VMSLG T_0, T_3, T_5, T_5
VMSLG T_0, T_4, T_6, T_6
// store r for final block in GR
VLGVG $1, T_2, RSAVE_0 // c
VLGVG $1, T_3, RSAVE_1 // c
VLGVG $1, T_4, RSAVE_2 // c
VLGVG $1, T_5, R5SAVE_1 // c
VLGVG $1, T_6, R5SAVE_2 // c
// initialize h
VZERO H0_0
VZERO H1_0
VZERO H2_0
VZERO H0_1
VZERO H1_1
VZERO H2_1
// initialize pointer for reduce constants
MOVD $·reduce<>(SB), R12
// calculate r**2 and 20*(r**2)
VZERO R_0
VZERO R_1
VZERO R_2
SQUARE(T_2, T_3, T_4, T_6, R_0, R_1, R_2, T_1, T_5, T_7)
REDUCE2(R_0, R_1, R_2, M0, M1, M2, M3, M4, R5_1, R5_2, M5, T_1)
VZERO R5_1
VZERO R5_2
VMSLG T_0, R_1, R5_1, R5_1
VMSLG T_0, R_2, R5_2, R5_2
// skip r**4 calculation if 3 blocks or less
CMPBLE R3, $48, b4
// calculate r**4 and 20*(r**4)
VZERO T_8
VZERO T_9
VZERO T_10
SQUARE(R_0, R_1, R_2, R5_2, T_8, T_9, T_10, T_1, T_5, T_7)
REDUCE2(T_8, T_9, T_10, M0, M1, M2, M3, M4, T_2, T_3, M5, T_1)
VZERO T_2
VZERO T_3
VMSLG T_0, T_9, T_2, T_2
VMSLG T_0, T_10, T_3, T_3
// put r**2 to the right and r**4 to the left of R_0, R_1, R_2
VSLDB $8, T_8, T_8, T_8
VSLDB $8, T_9, T_9, T_9
VSLDB $8, T_10, T_10, T_10
VSLDB $8, T_2, T_2, T_2
VSLDB $8, T_3, T_3, T_3
VO T_8, R_0, R_0
VO T_9, R_1, R_1
VO T_10, R_2, R_2
VO T_2, R5_1, R5_1
VO T_3, R5_2, R5_2
CMPBLE R3, $80, load // less than or equal to 5 blocks in message
// 6(or 5+1) blocks
SUB $81, R3
VLM (R2), M0, M4
VLL R3, 80(R2), M5
ADD $1, R3
MOVBZ $1, R0
CMPBGE R3, $16, 2(PC)
VLVGB R3, R0, M5
MOVD $96(R2), R2
EXPACC(M0, M1, H0_0, H1_0, H2_0, H0_1, H1_1, H2_1, T_0, T_1, T_2, T_3)
EXPACC(M2, M3, H0_0, H1_0, H2_0, H0_1, H1_1, H2_1, T_0, T_1, T_2, T_3)
VLEIB $2, $1, H2_0
VLEIB $2, $1, H2_1
VLEIB $10, $1, H2_0
VLEIB $10, $1, H2_1
VZERO M0
VZERO M1
VZERO M2
VZERO M3
VZERO T_4
VZERO T_10
EXPACC(M4, M5, M0, M1, M2, M3, T_4, T_10, T_0, T_1, T_2, T_3)
VLR T_4, M4
VLEIB $10, $1, M2
CMPBLT R3, $16, 2(PC)
VLEIB $10, $1, T_10
MULTIPLY(H0_0, H1_0, H2_0, H0_1, H1_1, H2_1, R_0, R_1, R_2, R5_1, R5_2, M0, M1, M2, M3, M4, T_10, T_0, T_1, T_2, T_3, T_4, T_5, T_6, T_7, T_8, T_9)
REDUCE(H0_0, H1_0, H2_0, H0_1, H1_1, H2_1, T_10, M0, M1, M2, M3, M4, T_4, T_5, T_2, T_7, T_8, T_9)
VMRHG V0, H0_1, H0_0
VMRHG V0, H1_1, H1_0
VMRHG V0, H2_1, H2_0
VMRLG V0, H0_1, H0_1
VMRLG V0, H1_1, H1_1
VMRLG V0, H2_1, H2_1
SUB $16, R3
CMPBLE R3, $0, square
load:
// load EX0, EX1 and EX2
MOVD $·c<>(SB), R5
VLM (R5), EX0, EX2
loop:
CMPBLE R3, $64, add // b4 // last 4 or less blocks left
// next 4 full blocks
VLM (R2), M2, M5
SUB $64, R3
MOVD $64(R2), R2
REDUCE(H0_0, H1_0, H2_0, H0_1, H1_1, H2_1, T_10, M0, M1, T_0, T_1, T_3, T_4, T_5, T_2, T_7, T_8, T_9)
// expacc in-lined to create [m2, m3] limbs
VGBM $0x3f3f, T_0 // 44 bit clear mask
VGBM $0x1f1f, T_1 // 40 bit clear mask
VPERM M2, M3, EX0, T_3
VESRLG $4, T_0, T_0 // 44 bit clear mask ready
VPERM M2, M3, EX1, T_4
VPERM M2, M3, EX2, T_5
VN T_0, T_3, T_3
VESRLG $4, T_4, T_4
VN T_1, T_5, T_5
VN T_0, T_4, T_4
VMRHG H0_1, T_3, H0_0
VMRHG H1_1, T_4, H1_0
VMRHG H2_1, T_5, H2_0
VMRLG H0_1, T_3, H0_1
VMRLG H1_1, T_4, H1_1
VMRLG H2_1, T_5, H2_1
VLEIB $10, $1, H2_0
VLEIB $10, $1, H2_1
VPERM M4, M5, EX0, T_3
VPERM M4, M5, EX1, T_4
VPERM M4, M5, EX2, T_5
VN T_0, T_3, T_3
VESRLG $4, T_4, T_4
VN T_1, T_5, T_5
VN T_0, T_4, T_4
VMRHG V0, T_3, M0
VMRHG V0, T_4, M1
VMRHG V0, T_5, M2
VMRLG V0, T_3, M3
VMRLG V0, T_4, M4
VMRLG V0, T_5, M5
VLEIB $10, $1, M2
VLEIB $10, $1, M5
MULTIPLY(H0_0, H1_0, H2_0, H0_1, H1_1, H2_1, R_0, R_1, R_2, R5_1, R5_2, M0, M1, M2, M3, M4, M5, T_0, T_1, T_2, T_3, T_4, T_5, T_6, T_7, T_8, T_9)
CMPBNE R3, $0, loop
REDUCE(H0_0, H1_0, H2_0, H0_1, H1_1, H2_1, T_10, M0, M1, M3, M4, M5, T_4, T_5, T_2, T_7, T_8, T_9)
VMRHG V0, H0_1, H0_0
VMRHG V0, H1_1, H1_0
VMRHG V0, H2_1, H2_0
VMRLG V0, H0_1, H0_1
VMRLG V0, H1_1, H1_1
VMRLG V0, H2_1, H2_1
// load EX0, EX1, EX2
MOVD $·constants<>(SB), R5
VLM (R5), EX0, EX2
// sum vectors
VAQ H0_0, H0_1, H0_0
VAQ H1_0, H1_1, H1_0
VAQ H2_0, H2_1, H2_0
// h may be >= 2*(2**130-5) so we need to reduce it again
// M0...M4 are used as temps here
REDUCE2(H0_0, H1_0, H2_0, M0, M1, M2, M3, M4, T_9, T_10, H0_1, M5)
next: // carry h1->h2
VLEIB $7, $0x28, T_1
VREPIB $4, T_2
VGBM $0x003F, T_3
VESRLG $4, T_3
// byte shift
VSRLB T_1, H1_0, T_4
// bit shift
VSRL T_2, T_4, T_4
// clear h1 carry bits
VN T_3, H1_0, H1_0
// add carry
VAQ T_4, H2_0, H2_0
// h is now < 2*(2**130-5)
// pack h into h1 (hi) and h0 (lo)
PACK(H0_0, H1_0, H2_0)
// if h > 2**130-5 then h -= 2**130-5
MOD(H0_0, H1_0, T_0, T_1, T_2)
// h += s
MOVD $·bswapMask<>(SB), R5
VL (R5), T_1
VL 16(R4), T_0
VPERM T_0, T_0, T_1, T_0 // reverse bytes (to big)
VAQ T_0, H0_0, H0_0
VPERM H0_0, H0_0, T_1, H0_0 // reverse bytes (to little)
VST H0_0, (R1)
RET
add:
// load EX0, EX1, EX2
MOVD $·constants<>(SB), R5
VLM (R5), EX0, EX2
REDUCE(H0_0, H1_0, H2_0, H0_1, H1_1, H2_1, T_10, M0, M1, M3, M4, M5, T_4, T_5, T_2, T_7, T_8, T_9)
VMRHG V0, H0_1, H0_0
VMRHG V0, H1_1, H1_0
VMRHG V0, H2_1, H2_0
VMRLG V0, H0_1, H0_1
VMRLG V0, H1_1, H1_1
VMRLG V0, H2_1, H2_1
CMPBLE R3, $64, b4
b4:
CMPBLE R3, $48, b3 // 3 blocks or less
// 4(3+1) blocks remaining
SUB $49, R3
VLM (R2), M0, M2
VLL R3, 48(R2), M3
ADD $1, R3
MOVBZ $1, R0
CMPBEQ R3, $16, 2(PC)
VLVGB R3, R0, M3
MOVD $64(R2), R2
EXPACC(M0, M1, H0_0, H1_0, H2_0, H0_1, H1_1, H2_1, T_0, T_1, T_2, T_3)
VLEIB $10, $1, H2_0
VLEIB $10, $1, H2_1
VZERO M0
VZERO M1
VZERO M4
VZERO M5
VZERO T_4
VZERO T_10
EXPACC(M2, M3, M0, M1, M4, M5, T_4, T_10, T_0, T_1, T_2, T_3)
VLR T_4, M2
VLEIB $10, $1, M4
CMPBNE R3, $16, 2(PC)
VLEIB $10, $1, T_10
MULTIPLY(H0_0, H1_0, H2_0, H0_1, H1_1, H2_1, R_0, R_1, R_2, R5_1, R5_2, M0, M1, M4, M5, M2, T_10, T_0, T_1, T_2, T_3, T_4, T_5, T_6, T_7, T_8, T_9)
REDUCE(H0_0, H1_0, H2_0, H0_1, H1_1, H2_1, T_10, M0, M1, M3, M4, M5, T_4, T_5, T_2, T_7, T_8, T_9)
VMRHG V0, H0_1, H0_0
VMRHG V0, H1_1, H1_0
VMRHG V0, H2_1, H2_0
VMRLG V0, H0_1, H0_1
VMRLG V0, H1_1, H1_1
VMRLG V0, H2_1, H2_1
SUB $16, R3
CMPBLE R3, $0, square // this condition must always hold true!
b3:
CMPBLE R3, $32, b2
// 3 blocks remaining
// setup [r²,r]
VSLDB $8, R_0, R_0, R_0
VSLDB $8, R_1, R_1, R_1
VSLDB $8, R_2, R_2, R_2
VSLDB $8, R5_1, R5_1, R5_1
VSLDB $8, R5_2, R5_2, R5_2
VLVGG $1, RSAVE_0, R_0
VLVGG $1, RSAVE_1, R_1
VLVGG $1, RSAVE_2, R_2
VLVGG $1, R5SAVE_1, R5_1
VLVGG $1, R5SAVE_2, R5_2
// setup [h0, h1]
VSLDB $8, H0_0, H0_0, H0_0
VSLDB $8, H1_0, H1_0, H1_0
VSLDB $8, H2_0, H2_0, H2_0
VO H0_1, H0_0, H0_0
VO H1_1, H1_0, H1_0
VO H2_1, H2_0, H2_0
VZERO H0_1
VZERO H1_1
VZERO H2_1
VZERO M0
VZERO M1
VZERO M2
VZERO M3
VZERO M4
VZERO M5
// H*[r**2, r]
MULTIPLY(H0_0, H1_0, H2_0, H0_1, H1_1, H2_1, R_0, R_1, R_2, R5_1, R5_2, M0, M1, M2, M3, M4, M5, T_0, T_1, T_2, T_3, T_4, T_5, T_6, T_7, T_8, T_9)
REDUCE2(H0_0, H1_0, H2_0, M0, M1, M2, M3, M4, H0_1, H1_1, T_10, M5)
SUB $33, R3
VLM (R2), M0, M1
VLL R3, 32(R2), M2
ADD $1, R3
MOVBZ $1, R0
CMPBEQ R3, $16, 2(PC)
VLVGB R3, R0, M2
// H += m0
VZERO T_1
VZERO T_2
VZERO T_3
EXPACC2(M0, T_1, T_2, T_3, T_4, T_5, T_6)
VLEIB $10, $1, T_3
VAG H0_0, T_1, H0_0
VAG H1_0, T_2, H1_0
VAG H2_0, T_3, H2_0
VZERO M0
VZERO M3
VZERO M4
VZERO M5
VZERO T_10
// (H+m0)*r
MULTIPLY(H0_0, H1_0, H2_0, H0_1, H1_1, H2_1, R_0, R_1, R_2, R5_1, R5_2, M0, M3, M4, M5, V0, T_10, T_0, T_1, T_2, T_3, T_4, T_5, T_6, T_7, T_8, T_9)
REDUCE2(H0_0, H1_0, H2_0, M0, M3, M4, M5, T_10, H0_1, H1_1, H2_1, T_9)
// H += m1
VZERO V0
VZERO T_1
VZERO T_2
VZERO T_3
EXPACC2(M1, T_1, T_2, T_3, T_4, T_5, T_6)
VLEIB $10, $1, T_3
VAQ H0_0, T_1, H0_0
VAQ H1_0, T_2, H1_0
VAQ H2_0, T_3, H2_0
REDUCE2(H0_0, H1_0, H2_0, M0, M3, M4, M5, T_9, H0_1, H1_1, H2_1, T_10)
// [H, m2] * [r**2, r]
EXPACC2(M2, H0_0, H1_0, H2_0, T_1, T_2, T_3)
CMPBNE R3, $16, 2(PC)
VLEIB $10, $1, H2_0
VZERO M0
VZERO M1
VZERO M2
VZERO M3
VZERO M4
VZERO M5
MULTIPLY(H0_0, H1_0, H2_0, H0_1, H1_1, H2_1, R_0, R_1, R_2, R5_1, R5_2, M0, M1, M2, M3, M4, M5, T_0, T_1, T_2, T_3, T_4, T_5, T_6, T_7, T_8, T_9)
REDUCE2(H0_0, H1_0, H2_0, M0, M1, M2, M3, M4, H0_1, H1_1, M5, T_10)
SUB $16, R3
CMPBLE R3, $0, next // this condition must always hold true!
b2:
CMPBLE R3, $16, b1
// 2 blocks remaining
// setup [r²,r]
VSLDB $8, R_0, R_0, R_0
VSLDB $8, R_1, R_1, R_1
VSLDB $8, R_2, R_2, R_2
VSLDB $8, R5_1, R5_1, R5_1
VSLDB $8, R5_2, R5_2, R5_2
VLVGG $1, RSAVE_0, R_0
VLVGG $1, RSAVE_1, R_1
VLVGG $1, RSAVE_2, R_2
VLVGG $1, R5SAVE_1, R5_1
VLVGG $1, R5SAVE_2, R5_2
// setup [h0, h1]
VSLDB $8, H0_0, H0_0, H0_0
VSLDB $8, H1_0, H1_0, H1_0
VSLDB $8, H2_0, H2_0, H2_0
VO H0_1, H0_0, H0_0
VO H1_1, H1_0, H1_0
VO H2_1, H2_0, H2_0
VZERO H0_1
VZERO H1_1
VZERO H2_1
VZERO M0
VZERO M1
VZERO M2
VZERO M3
VZERO M4
VZERO M5
// H*[r**2, r]
MULTIPLY(H0_0, H1_0, H2_0, H0_1, H1_1, H2_1, R_0, R_1, R_2, R5_1, R5_2, M0, M1, M2, M3, M4, M5, T_0, T_1, T_2, T_3, T_4, T_5, T_6, T_7, T_8, T_9)
REDUCE(H0_0, H1_0, H2_0, H0_1, H1_1, H2_1, T_10, M0, M1, M2, M3, M4, T_4, T_5, T_2, T_7, T_8, T_9)
VMRHG V0, H0_1, H0_0
VMRHG V0, H1_1, H1_0
VMRHG V0, H2_1, H2_0
VMRLG V0, H0_1, H0_1
VMRLG V0, H1_1, H1_1
VMRLG V0, H2_1, H2_1
// move h to the left and 0s at the right
VSLDB $8, H0_0, H0_0, H0_0
VSLDB $8, H1_0, H1_0, H1_0
VSLDB $8, H2_0, H2_0, H2_0
// get message blocks and append 1 to start
SUB $17, R3
VL (R2), M0
VLL R3, 16(R2), M1
ADD $1, R3
MOVBZ $1, R0
CMPBEQ R3, $16, 2(PC)
VLVGB R3, R0, M1
VZERO T_6
VZERO T_7
VZERO T_8
EXPACC2(M0, T_6, T_7, T_8, T_1, T_2, T_3)
EXPACC2(M1, T_6, T_7, T_8, T_1, T_2, T_3)
VLEIB $2, $1, T_8
CMPBNE R3, $16, 2(PC)
VLEIB $10, $1, T_8
// add [m0, m1] to h
VAG H0_0, T_6, H0_0
VAG H1_0, T_7, H1_0
VAG H2_0, T_8, H2_0
VZERO M2
VZERO M3
VZERO M4
VZERO M5
VZERO T_10
VZERO M0
// at this point R_0 .. R5_2 look like [r**2, r]
MULTIPLY(H0_0, H1_0, H2_0, H0_1, H1_1, H2_1, R_0, R_1, R_2, R5_1, R5_2, M2, M3, M4, M5, T_10, M0, T_0, T_1, T_2, T_3, T_4, T_5, T_6, T_7, T_8, T_9)
REDUCE2(H0_0, H1_0, H2_0, M2, M3, M4, M5, T_9, H0_1, H1_1, H2_1, T_10)
SUB $16, R3, R3
CMPBLE R3, $0, next
b1:
CMPBLE R3, $0, next
// 1 block remaining
// setup [r²,r]
VSLDB $8, R_0, R_0, R_0
VSLDB $8, R_1, R_1, R_1
VSLDB $8, R_2, R_2, R_2
VSLDB $8, R5_1, R5_1, R5_1
VSLDB $8, R5_2, R5_2, R5_2
VLVGG $1, RSAVE_0, R_0
VLVGG $1, RSAVE_1, R_1
VLVGG $1, RSAVE_2, R_2
VLVGG $1, R5SAVE_1, R5_1
VLVGG $1, R5SAVE_2, R5_2
// setup [h0, h1]
VSLDB $8, H0_0, H0_0, H0_0
VSLDB $8, H1_0, H1_0, H1_0
VSLDB $8, H2_0, H2_0, H2_0
VO H0_1, H0_0, H0_0
VO H1_1, H1_0, H1_0
VO H2_1, H2_0, H2_0
VZERO H0_1
VZERO H1_1
VZERO H2_1
VZERO M0
VZERO M1
VZERO M2
VZERO M3
VZERO M4
VZERO M5
// H*[r**2, r]
MULTIPLY(H0_0, H1_0, H2_0, H0_1, H1_1, H2_1, R_0, R_1, R_2, R5_1, R5_2, M0, M1, M2, M3, M4, M5, T_0, T_1, T_2, T_3, T_4, T_5, T_6, T_7, T_8, T_9)
REDUCE2(H0_0, H1_0, H2_0, M0, M1, M2, M3, M4, T_9, T_10, H0_1, M5)
// set up [0, m0] limbs
SUB $1, R3
VLL R3, (R2), M0
ADD $1, R3
MOVBZ $1, R0
CMPBEQ R3, $16, 2(PC)
VLVGB R3, R0, M0
VZERO T_1
VZERO T_2
VZERO T_3
EXPACC2(M0, T_1, T_2, T_3, T_4, T_5, T_6)// limbs: [0, m]
CMPBNE R3, $16, 2(PC)
VLEIB $10, $1, T_3
// h+m0
VAQ H0_0, T_1, H0_0
VAQ H1_0, T_2, H1_0
VAQ H2_0, T_3, H2_0
VZERO M0
VZERO M1
VZERO M2
VZERO M3
VZERO M4
VZERO M5
MULTIPLY(H0_0, H1_0, H2_0, H0_1, H1_1, H2_1, R_0, R_1, R_2, R5_1, R5_2, M0, M1, M2, M3, M4, M5, T_0, T_1, T_2, T_3, T_4, T_5, T_6, T_7, T_8, T_9)
REDUCE2(H0_0, H1_0, H2_0, M0, M1, M2, M3, M4, T_9, T_10, H0_1, M5)
BR next
square:
// setup [r²,r]
VSLDB $8, R_0, R_0, R_0
VSLDB $8, R_1, R_1, R_1
VSLDB $8, R_2, R_2, R_2
VSLDB $8, R5_1, R5_1, R5_1
VSLDB $8, R5_2, R5_2, R5_2
VLVGG $1, RSAVE_0, R_0
VLVGG $1, RSAVE_1, R_1
VLVGG $1, RSAVE_2, R_2
VLVGG $1, R5SAVE_1, R5_1
VLVGG $1, R5SAVE_2, R5_2
// setup [h0, h1]
VSLDB $8, H0_0, H0_0, H0_0
VSLDB $8, H1_0, H1_0, H1_0
VSLDB $8, H2_0, H2_0, H2_0
VO H0_1, H0_0, H0_0
VO H1_1, H1_0, H1_0
VO H2_1, H2_0, H2_0
VZERO H0_1
VZERO H1_1
VZERO H2_1
VZERO M0
VZERO M1
VZERO M2
VZERO M3
VZERO M4
VZERO M5
// (h0*r**2) + (h1*r)
MULTIPLY(H0_0, H1_0, H2_0, H0_1, H1_1, H2_1, R_0, R_1, R_2, R5_1, R5_2, M0, M1, M2, M3, M4, M5, T_0, T_1, T_2, T_3, T_4, T_5, T_6, T_7, T_8, T_9)
REDUCE2(H0_0, H1_0, H2_0, M0, M1, M2, M3, M4, T_9, T_10, H0_1, M5)
BR next
TEXT ·hasVMSLFacility(SB), NOSPLIT, $24-1
MOVD $x-24(SP), R1
XC $24, 0(R1), 0(R1) // clear the storage
MOVD $2, R0 // R0 is the number of double words stored -1
WORD $0xB2B01000 // STFLE 0(R1)
XOR R0, R0 // reset the value of R0
MOVBZ z-8(SP), R1
AND $0x01, R1
BEQ novmsl
vectorinstalled:
// check if the vector instruction has been enabled
VLEIB $0, $0xF, V16
VLGVB $0, V16, R1
CMPBNE R1, $0xF, novmsl
MOVB $1, ret+0(FP) // have vx
RET
novmsl:
MOVB $0, ret+0(FP) // no vx
RET

View File

@ -222,6 +222,11 @@ type openSSHCertSigner struct {
signer Signer signer Signer
} }
type algorithmOpenSSHCertSigner struct {
*openSSHCertSigner
algorithmSigner AlgorithmSigner
}
// NewCertSigner returns a Signer that signs with the given Certificate, whose // NewCertSigner returns a Signer that signs with the given Certificate, whose
// private key is held by signer. It returns an error if the public key in cert // private key is held by signer. It returns an error if the public key in cert
// doesn't match the key used by signer. // doesn't match the key used by signer.
@ -230,8 +235,13 @@ func NewCertSigner(cert *Certificate, signer Signer) (Signer, error) {
return nil, errors.New("ssh: signer and cert have different public key") return nil, errors.New("ssh: signer and cert have different public key")
} }
if algorithmSigner, ok := signer.(AlgorithmSigner); ok {
return &algorithmOpenSSHCertSigner{
&openSSHCertSigner{cert, signer}, algorithmSigner}, nil
} else {
return &openSSHCertSigner{cert, signer}, nil return &openSSHCertSigner{cert, signer}, nil
} }
}
func (s *openSSHCertSigner) Sign(rand io.Reader, data []byte) (*Signature, error) { func (s *openSSHCertSigner) Sign(rand io.Reader, data []byte) (*Signature, error) {
return s.signer.Sign(rand, data) return s.signer.Sign(rand, data)
@ -241,6 +251,10 @@ func (s *openSSHCertSigner) PublicKey() PublicKey {
return s.pub return s.pub
} }
func (s *algorithmOpenSSHCertSigner) SignWithAlgorithm(rand io.Reader, data []byte, algorithm string) (*Signature, error) {
return s.algorithmSigner.SignWithAlgorithm(rand, data, algorithm)
}
const sourceAddressCriticalOption = "source-address" const sourceAddressCriticalOption = "source-address"
// CertChecker does the work of verifying a certificate. Its methods // CertChecker does the work of verifying a certificate. Its methods

View File

@ -16,6 +16,7 @@ import (
"hash" "hash"
"io" "io"
"io/ioutil" "io/ioutil"
"math/bits"
"golang.org/x/crypto/internal/chacha20" "golang.org/x/crypto/internal/chacha20"
"golang.org/x/crypto/poly1305" "golang.org/x/crypto/poly1305"
@ -641,8 +642,8 @@ const chacha20Poly1305ID = "chacha20-poly1305@openssh.com"
// the methods here also implement padding, which RFC4253 Section 6 // the methods here also implement padding, which RFC4253 Section 6
// also requires of stream ciphers. // also requires of stream ciphers.
type chacha20Poly1305Cipher struct { type chacha20Poly1305Cipher struct {
lengthKey [32]byte lengthKey [8]uint32
contentKey [32]byte contentKey [8]uint32
buf []byte buf []byte
} }
@ -655,20 +656,21 @@ func newChaCha20Cipher(key, unusedIV, unusedMACKey []byte, unusedAlgs directionA
buf: make([]byte, 256), buf: make([]byte, 256),
} }
copy(c.contentKey[:], key[:32]) for i := range c.contentKey {
copy(c.lengthKey[:], key[32:]) c.contentKey[i] = binary.LittleEndian.Uint32(key[i*4 : (i+1)*4])
}
for i := range c.lengthKey {
c.lengthKey[i] = binary.LittleEndian.Uint32(key[(i+8)*4 : (i+9)*4])
}
return c, nil return c, nil
} }
// The Poly1305 key is obtained by encrypting 32 0-bytes.
var chacha20PolyKeyInput [32]byte
func (c *chacha20Poly1305Cipher) readPacket(seqNum uint32, r io.Reader) ([]byte, error) { func (c *chacha20Poly1305Cipher) readPacket(seqNum uint32, r io.Reader) ([]byte, error) {
var counter [16]byte nonce := [3]uint32{0, 0, bits.ReverseBytes32(seqNum)}
binary.BigEndian.PutUint64(counter[8:], uint64(seqNum)) s := chacha20.New(c.contentKey, nonce)
var polyKey [32]byte var polyKey [32]byte
chacha20.XORKeyStream(polyKey[:], chacha20PolyKeyInput[:], &counter, &c.contentKey) s.XORKeyStream(polyKey[:], polyKey[:])
s.Advance() // skip next 32 bytes
encryptedLength := c.buf[:4] encryptedLength := c.buf[:4]
if _, err := io.ReadFull(r, encryptedLength); err != nil { if _, err := io.ReadFull(r, encryptedLength); err != nil {
@ -676,7 +678,7 @@ func (c *chacha20Poly1305Cipher) readPacket(seqNum uint32, r io.Reader) ([]byte,
} }
var lenBytes [4]byte var lenBytes [4]byte
chacha20.XORKeyStream(lenBytes[:], encryptedLength, &counter, &c.lengthKey) chacha20.New(c.lengthKey, nonce).XORKeyStream(lenBytes[:], encryptedLength)
length := binary.BigEndian.Uint32(lenBytes[:]) length := binary.BigEndian.Uint32(lenBytes[:])
if length > maxPacket { if length > maxPacket {
@ -702,10 +704,8 @@ func (c *chacha20Poly1305Cipher) readPacket(seqNum uint32, r io.Reader) ([]byte,
return nil, errors.New("ssh: MAC failure") return nil, errors.New("ssh: MAC failure")
} }
counter[0] = 1
plain := c.buf[4:contentEnd] plain := c.buf[4:contentEnd]
chacha20.XORKeyStream(plain, plain, &counter, &c.contentKey) s.XORKeyStream(plain, plain)
padding := plain[0] padding := plain[0]
if padding < 4 { if padding < 4 {
@ -724,11 +724,11 @@ func (c *chacha20Poly1305Cipher) readPacket(seqNum uint32, r io.Reader) ([]byte,
} }
func (c *chacha20Poly1305Cipher) writePacket(seqNum uint32, w io.Writer, rand io.Reader, payload []byte) error { func (c *chacha20Poly1305Cipher) writePacket(seqNum uint32, w io.Writer, rand io.Reader, payload []byte) error {
var counter [16]byte nonce := [3]uint32{0, 0, bits.ReverseBytes32(seqNum)}
binary.BigEndian.PutUint64(counter[8:], uint64(seqNum)) s := chacha20.New(c.contentKey, nonce)
var polyKey [32]byte var polyKey [32]byte
chacha20.XORKeyStream(polyKey[:], chacha20PolyKeyInput[:], &counter, &c.contentKey) s.XORKeyStream(polyKey[:], polyKey[:])
s.Advance() // skip next 32 bytes
// There is no blocksize, so fall back to multiple of 8 byte // There is no blocksize, so fall back to multiple of 8 byte
// padding, as described in RFC 4253, Sec 6. // padding, as described in RFC 4253, Sec 6.
@ -748,7 +748,7 @@ func (c *chacha20Poly1305Cipher) writePacket(seqNum uint32, w io.Writer, rand io
} }
binary.BigEndian.PutUint32(c.buf, uint32(1+len(payload)+padding)) binary.BigEndian.PutUint32(c.buf, uint32(1+len(payload)+padding))
chacha20.XORKeyStream(c.buf, c.buf[:4], &counter, &c.lengthKey) chacha20.New(c.lengthKey, nonce).XORKeyStream(c.buf, c.buf[:4])
c.buf[4] = byte(padding) c.buf[4] = byte(padding)
copy(c.buf[5:], payload) copy(c.buf[5:], payload)
packetEnd := 5 + len(payload) + padding packetEnd := 5 + len(payload) + padding
@ -756,8 +756,7 @@ func (c *chacha20Poly1305Cipher) writePacket(seqNum uint32, w io.Writer, rand io
return err return err
} }
counter[0] = 1 s.XORKeyStream(c.buf[4:], c.buf[4:packetEnd])
chacha20.XORKeyStream(c.buf[4:], c.buf[4:packetEnd], &counter, &c.contentKey)
var mac [poly1305.TagSize]byte var mac [poly1305.TagSize]byte
poly1305.Sum(&mac, c.buf[:packetEnd], &polyKey) poly1305.Sum(&mac, c.buf[:packetEnd], &polyKey)

View File

@ -19,6 +19,8 @@ import (
type Client struct { type Client struct {
Conn Conn
handleForwardsOnce sync.Once // guards calling (*Client).handleForwards
forwards forwardList // forwarded tcpip connections from the remote side forwards forwardList // forwarded tcpip connections from the remote side
mu sync.Mutex mu sync.Mutex
channelHandlers map[string]chan NewChannel channelHandlers map[string]chan NewChannel
@ -60,8 +62,6 @@ func NewClient(c Conn, chans <-chan NewChannel, reqs <-chan *Request) *Client {
conn.Wait() conn.Wait()
conn.forwards.closeAll() conn.forwards.closeAll()
}() }()
go conn.forwards.handleChannels(conn.HandleChannelOpen("forwarded-tcpip"))
go conn.forwards.handleChannels(conn.HandleChannelOpen("forwarded-streamlocal@openssh.com"))
return conn return conn
} }
@ -185,7 +185,7 @@ func Dial(network, addr string, config *ClientConfig) (*Client, error) {
// keys. A HostKeyCallback must return nil if the host key is OK, or // keys. A HostKeyCallback must return nil if the host key is OK, or
// an error to reject it. It receives the hostname as passed to Dial // an error to reject it. It receives the hostname as passed to Dial
// or NewClientConn. The remote address is the RemoteAddr of the // or NewClientConn. The remote address is the RemoteAddr of the
// net.Conn underlying the the SSH connection. // net.Conn underlying the SSH connection.
type HostKeyCallback func(hostname string, remote net.Addr, key PublicKey) error type HostKeyCallback func(hostname string, remote net.Addr, key PublicKey) error
// BannerCallback is the function type used for treat the banner sent by // BannerCallback is the function type used for treat the banner sent by

View File

@ -38,6 +38,16 @@ const (
KeyAlgoED25519 = "ssh-ed25519" KeyAlgoED25519 = "ssh-ed25519"
) )
// These constants represent non-default signature algorithms that are supported
// as algorithm parameters to AlgorithmSigner.SignWithAlgorithm methods. See
// [PROTOCOL.agent] section 4.5.1 and
// https://tools.ietf.org/html/draft-ietf-curdle-rsa-sha2-10
const (
SigAlgoRSA = "ssh-rsa"
SigAlgoRSASHA2256 = "rsa-sha2-256"
SigAlgoRSASHA2512 = "rsa-sha2-512"
)
// parsePubKey parses a public key of the given algorithm. // parsePubKey parses a public key of the given algorithm.
// Use ParsePublicKey for keys with prepended algorithm. // Use ParsePublicKey for keys with prepended algorithm.
func parsePubKey(in []byte, algo string) (pubKey PublicKey, rest []byte, err error) { func parsePubKey(in []byte, algo string) (pubKey PublicKey, rest []byte, err error) {
@ -301,6 +311,19 @@ type Signer interface {
Sign(rand io.Reader, data []byte) (*Signature, error) Sign(rand io.Reader, data []byte) (*Signature, error)
} }
// A AlgorithmSigner is a Signer that also supports specifying a specific
// algorithm to use for signing.
type AlgorithmSigner interface {
Signer
// SignWithAlgorithm is like Signer.Sign, but allows specification of a
// non-default signing algorithm. See the SigAlgo* constants in this
// package for signature algorithms supported by this package. Callers may
// pass an empty string for the algorithm in which case the AlgorithmSigner
// will use its default algorithm.
SignWithAlgorithm(rand io.Reader, data []byte, algorithm string) (*Signature, error)
}
type rsaPublicKey rsa.PublicKey type rsaPublicKey rsa.PublicKey
func (r *rsaPublicKey) Type() string { func (r *rsaPublicKey) Type() string {
@ -349,13 +372,21 @@ func (r *rsaPublicKey) Marshal() []byte {
} }
func (r *rsaPublicKey) Verify(data []byte, sig *Signature) error { func (r *rsaPublicKey) Verify(data []byte, sig *Signature) error {
if sig.Format != r.Type() { var hash crypto.Hash
switch sig.Format {
case SigAlgoRSA:
hash = crypto.SHA1
case SigAlgoRSASHA2256:
hash = crypto.SHA256
case SigAlgoRSASHA2512:
hash = crypto.SHA512
default:
return fmt.Errorf("ssh: signature type %s for key type %s", sig.Format, r.Type()) return fmt.Errorf("ssh: signature type %s for key type %s", sig.Format, r.Type())
} }
h := crypto.SHA1.New() h := hash.New()
h.Write(data) h.Write(data)
digest := h.Sum(nil) digest := h.Sum(nil)
return rsa.VerifyPKCS1v15((*rsa.PublicKey)(r), crypto.SHA1, digest, sig.Blob) return rsa.VerifyPKCS1v15((*rsa.PublicKey)(r), hash, digest, sig.Blob)
} }
func (r *rsaPublicKey) CryptoPublicKey() crypto.PublicKey { func (r *rsaPublicKey) CryptoPublicKey() crypto.PublicKey {
@ -459,6 +490,14 @@ func (k *dsaPrivateKey) PublicKey() PublicKey {
} }
func (k *dsaPrivateKey) Sign(rand io.Reader, data []byte) (*Signature, error) { func (k *dsaPrivateKey) Sign(rand io.Reader, data []byte) (*Signature, error) {
return k.SignWithAlgorithm(rand, data, "")
}
func (k *dsaPrivateKey) SignWithAlgorithm(rand io.Reader, data []byte, algorithm string) (*Signature, error) {
if algorithm != "" && algorithm != k.PublicKey().Type() {
return nil, fmt.Errorf("ssh: unsupported signature algorithm %s", algorithm)
}
h := crypto.SHA1.New() h := crypto.SHA1.New()
h.Write(data) h.Write(data)
digest := h.Sum(nil) digest := h.Sum(nil)
@ -691,10 +730,35 @@ func (s *wrappedSigner) PublicKey() PublicKey {
} }
func (s *wrappedSigner) Sign(rand io.Reader, data []byte) (*Signature, error) { func (s *wrappedSigner) Sign(rand io.Reader, data []byte) (*Signature, error) {
return s.SignWithAlgorithm(rand, data, "")
}
func (s *wrappedSigner) SignWithAlgorithm(rand io.Reader, data []byte, algorithm string) (*Signature, error) {
var hashFunc crypto.Hash var hashFunc crypto.Hash
if _, ok := s.pubKey.(*rsaPublicKey); ok {
// RSA keys support a few hash functions determined by the requested signature algorithm
switch algorithm {
case "", SigAlgoRSA:
algorithm = SigAlgoRSA
hashFunc = crypto.SHA1
case SigAlgoRSASHA2256:
hashFunc = crypto.SHA256
case SigAlgoRSASHA2512:
hashFunc = crypto.SHA512
default:
return nil, fmt.Errorf("ssh: unsupported signature algorithm %s", algorithm)
}
} else {
// The only supported algorithm for all other key types is the same as the type of the key
if algorithm == "" {
algorithm = s.pubKey.Type()
} else if algorithm != s.pubKey.Type() {
return nil, fmt.Errorf("ssh: unsupported signature algorithm %s", algorithm)
}
switch key := s.pubKey.(type) { switch key := s.pubKey.(type) {
case *rsaPublicKey, *dsaPublicKey: case *dsaPublicKey:
hashFunc = crypto.SHA1 hashFunc = crypto.SHA1
case *ecdsaPublicKey: case *ecdsaPublicKey:
hashFunc = ecHash(key.Curve) hashFunc = ecHash(key.Curve)
@ -702,6 +766,7 @@ func (s *wrappedSigner) Sign(rand io.Reader, data []byte) (*Signature, error) {
default: default:
return nil, fmt.Errorf("ssh: unsupported key type %T", key) return nil, fmt.Errorf("ssh: unsupported key type %T", key)
} }
}
var digest []byte var digest []byte
if hashFunc != 0 { if hashFunc != 0 {
@ -745,7 +810,7 @@ func (s *wrappedSigner) Sign(rand io.Reader, data []byte) (*Signature, error) {
} }
return &Signature{ return &Signature{
Format: s.pubKey.Type(), Format: algorithm,
Blob: signature, Blob: signature,
}, nil }, nil
} }
@ -803,7 +868,7 @@ func encryptedBlock(block *pem.Block) bool {
} }
// ParseRawPrivateKey returns a private key from a PEM encoded private key. It // ParseRawPrivateKey returns a private key from a PEM encoded private key. It
// supports RSA (PKCS#1), DSA (OpenSSL), and ECDSA private keys. // supports RSA (PKCS#1), PKCS#8, DSA (OpenSSL), and ECDSA private keys.
func ParseRawPrivateKey(pemBytes []byte) (interface{}, error) { func ParseRawPrivateKey(pemBytes []byte) (interface{}, error) {
block, _ := pem.Decode(pemBytes) block, _ := pem.Decode(pemBytes)
if block == nil { if block == nil {
@ -817,6 +882,9 @@ func ParseRawPrivateKey(pemBytes []byte) (interface{}, error) {
switch block.Type { switch block.Type {
case "RSA PRIVATE KEY": case "RSA PRIVATE KEY":
return x509.ParsePKCS1PrivateKey(block.Bytes) return x509.ParsePKCS1PrivateKey(block.Bytes)
// RFC5208 - https://tools.ietf.org/html/rfc5208
case "PRIVATE KEY":
return x509.ParsePKCS8PrivateKey(block.Bytes)
case "EC PRIVATE KEY": case "EC PRIVATE KEY":
return x509.ParseECPrivateKey(block.Bytes) return x509.ParseECPrivateKey(block.Bytes)
case "DSA PRIVATE KEY": case "DSA PRIVATE KEY":
@ -900,8 +968,8 @@ func ParseDSAPrivateKey(der []byte) (*dsa.PrivateKey, error) {
// Implemented based on the documentation at // Implemented based on the documentation at
// https://github.com/openssh/openssh-portable/blob/master/PROTOCOL.key // https://github.com/openssh/openssh-portable/blob/master/PROTOCOL.key
func parseOpenSSHPrivateKey(key []byte) (crypto.PrivateKey, error) { func parseOpenSSHPrivateKey(key []byte) (crypto.PrivateKey, error) {
magic := append([]byte("openssh-key-v1"), 0) const magic = "openssh-key-v1\x00"
if !bytes.Equal(magic, key[0:len(magic)]) { if len(key) < len(magic) || string(key[:len(magic)]) != magic {
return nil, errors.New("ssh: invalid openssh private key format") return nil, errors.New("ssh: invalid openssh private key format")
} }
remaining := key[len(magic):] remaining := key[len(magic):]

View File

@ -404,7 +404,7 @@ userAuthLoop:
perms, authErr = config.PasswordCallback(s, password) perms, authErr = config.PasswordCallback(s, password)
case "keyboard-interactive": case "keyboard-interactive":
if config.KeyboardInteractiveCallback == nil { if config.KeyboardInteractiveCallback == nil {
authErr = errors.New("ssh: keyboard-interactive auth not configubred") authErr = errors.New("ssh: keyboard-interactive auth not configured")
break break
} }

View File

@ -32,6 +32,7 @@ type streamLocalChannelForwardMsg struct {
// ListenUnix is similar to ListenTCP but uses a Unix domain socket. // ListenUnix is similar to ListenTCP but uses a Unix domain socket.
func (c *Client) ListenUnix(socketPath string) (net.Listener, error) { func (c *Client) ListenUnix(socketPath string) (net.Listener, error) {
c.handleForwardsOnce.Do(c.handleForwards)
m := streamLocalChannelForwardMsg{ m := streamLocalChannelForwardMsg{
socketPath, socketPath,
} }

View File

@ -90,10 +90,19 @@ type channelForwardMsg struct {
rport uint32 rport uint32
} }
// handleForwards starts goroutines handling forwarded connections.
// It's called on first use by (*Client).ListenTCP to not launch
// goroutines until needed.
func (c *Client) handleForwards() {
go c.forwards.handleChannels(c.HandleChannelOpen("forwarded-tcpip"))
go c.forwards.handleChannels(c.HandleChannelOpen("forwarded-streamlocal@openssh.com"))
}
// ListenTCP requests the remote peer open a listening socket // ListenTCP requests the remote peer open a listening socket
// on laddr. Incoming connections will be available by calling // on laddr. Incoming connections will be available by calling
// Accept on the returned net.Listener. // Accept on the returned net.Listener.
func (c *Client) ListenTCP(laddr *net.TCPAddr) (net.Listener, error) { func (c *Client) ListenTCP(laddr *net.TCPAddr) (net.Listener, error) {
c.handleForwardsOnce.Do(c.handleForwards)
if laddr.Port == 0 && isBrokenOpenSSHVersion(string(c.ServerVersion())) { if laddr.Port == 0 && isBrokenOpenSSHVersion(string(c.ServerVersion())) {
return c.autoPortListenWorkaround(laddr) return c.autoPortListenWorkaround(laddr)
} }

View File

@ -108,9 +108,7 @@ func ReadPassword(fd int) ([]byte, error) {
return nil, err return nil, err
} }
defer func() { defer unix.IoctlSetTermios(fd, ioctlWriteTermios, termios)
unix.IoctlSetTermios(fd, ioctlWriteTermios, termios)
}()
return readPasswordLine(passwordReader(fd)) return readPasswordLine(passwordReader(fd))
} }

View File

@ -14,7 +14,7 @@ import (
// State contains the state of a terminal. // State contains the state of a terminal.
type State struct { type State struct {
state *unix.Termios termios unix.Termios
} }
// IsTerminal returns true if the given file descriptor is a terminal. // IsTerminal returns true if the given file descriptor is a terminal.
@ -75,47 +75,43 @@ func ReadPassword(fd int) ([]byte, error) {
// restored. // restored.
// see http://cr.illumos.org/~webrev/andy_js/1060/ // see http://cr.illumos.org/~webrev/andy_js/1060/
func MakeRaw(fd int) (*State, error) { func MakeRaw(fd int) (*State, error) {
oldTermiosPtr, err := unix.IoctlGetTermios(fd, unix.TCGETS) termios, err := unix.IoctlGetTermios(fd, unix.TCGETS)
if err != nil { if err != nil {
return nil, err return nil, err
} }
oldTermios := *oldTermiosPtr
newTermios := oldTermios oldState := State{termios: *termios}
newTermios.Iflag &^= syscall.IGNBRK | syscall.BRKINT | syscall.PARMRK | syscall.ISTRIP | syscall.INLCR | syscall.IGNCR | syscall.ICRNL | syscall.IXON
newTermios.Oflag &^= syscall.OPOST
newTermios.Lflag &^= syscall.ECHO | syscall.ECHONL | syscall.ICANON | syscall.ISIG | syscall.IEXTEN
newTermios.Cflag &^= syscall.CSIZE | syscall.PARENB
newTermios.Cflag |= syscall.CS8
newTermios.Cc[unix.VMIN] = 1
newTermios.Cc[unix.VTIME] = 0
if err := unix.IoctlSetTermios(fd, unix.TCSETS, &newTermios); err != nil { termios.Iflag &^= unix.IGNBRK | unix.BRKINT | unix.PARMRK | unix.ISTRIP | unix.INLCR | unix.IGNCR | unix.ICRNL | unix.IXON
termios.Oflag &^= unix.OPOST
termios.Lflag &^= unix.ECHO | unix.ECHONL | unix.ICANON | unix.ISIG | unix.IEXTEN
termios.Cflag &^= unix.CSIZE | unix.PARENB
termios.Cflag |= unix.CS8
termios.Cc[unix.VMIN] = 1
termios.Cc[unix.VTIME] = 0
if err := unix.IoctlSetTermios(fd, unix.TCSETS, termios); err != nil {
return nil, err return nil, err
} }
return &State{ return &oldState, nil
state: oldTermiosPtr,
}, nil
} }
// Restore restores the terminal connected to the given file descriptor to a // Restore restores the terminal connected to the given file descriptor to a
// previous state. // previous state.
func Restore(fd int, oldState *State) error { func Restore(fd int, oldState *State) error {
return unix.IoctlSetTermios(fd, unix.TCSETS, oldState.state) return unix.IoctlSetTermios(fd, unix.TCSETS, &oldState.termios)
} }
// GetState returns the current state of a terminal which may be useful to // GetState returns the current state of a terminal which may be useful to
// restore the terminal after a signal. // restore the terminal after a signal.
func GetState(fd int) (*State, error) { func GetState(fd int) (*State, error) {
oldTermiosPtr, err := unix.IoctlGetTermios(fd, unix.TCGETS) termios, err := unix.IoctlGetTermios(fd, unix.TCGETS)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &State{ return &State{termios: *termios}, nil
state: oldTermiosPtr,
}, nil
} }
// GetSize returns the dimensions of the given terminal. // GetSize returns the dimensions of the given terminal.

View File

@ -89,9 +89,7 @@ func ReadPassword(fd int) ([]byte, error) {
return nil, err return nil, err
} }
defer func() { defer windows.SetConsoleMode(windows.Handle(fd), old)
windows.SetConsoleMode(windows.Handle(fd), old)
}()
var h windows.Handle var h windows.Handle
p, _ := windows.GetCurrentProcess() p, _ := windows.GetCurrentProcess()

13
vendor/modules.txt vendored
View File

@ -15,6 +15,8 @@ github.com/Philipp15b/go-steam/rwu
github.com/Philipp15b/go-steam/socialcache github.com/Philipp15b/go-steam/socialcache
# github.com/alecthomas/log4go v0.0.0-20160307011253-e5dc62318d9b # github.com/alecthomas/log4go v0.0.0-20160307011253-e5dc62318d9b
github.com/alecthomas/log4go github.com/alecthomas/log4go
# github.com/bwmarrin/discordgo v0.19.0
github.com/bwmarrin/discordgo
# github.com/davecgh/go-spew v1.1.0 # github.com/davecgh/go-spew v1.1.0
github.com/davecgh/go-spew/spew github.com/davecgh/go-spew/spew
# github.com/dfordsoft/golib v0.0.0-20180313113957-2ea3495aee1d # github.com/dfordsoft/golib v0.0.0-20180313113957-2ea3495aee1d
@ -34,7 +36,7 @@ github.com/google/gops/internal
github.com/google/gops/signal github.com/google/gops/signal
# github.com/gorilla/schema v0.0.0-20170317173100-f3c80893412c # github.com/gorilla/schema v0.0.0-20170317173100-f3c80893412c
github.com/gorilla/schema github.com/gorilla/schema
# github.com/gorilla/websocket v0.0.0-20170319172727-a91eba7f9777 # github.com/gorilla/websocket v1.4.0
github.com/gorilla/websocket github.com/gorilla/websocket
# github.com/hashicorp/golang-lru v0.0.0-20160813221303-0a025b7e63ad # github.com/hashicorp/golang-lru v0.0.0-20160813221303-0a025b7e63ad
github.com/hashicorp/golang-lru github.com/hashicorp/golang-lru
@ -66,8 +68,6 @@ github.com/labstack/gommon/random
github.com/lrstanley/girc github.com/lrstanley/girc
# github.com/magiconair/properties v0.0.0-20180217134545-2c9e95027885 # github.com/magiconair/properties v0.0.0-20180217134545-2c9e95027885
github.com/magiconair/properties github.com/magiconair/properties
# github.com/matterbridge/discordgo v0.0.0-20180806170629-ef40ff5ba64f
github.com/matterbridge/discordgo
# github.com/matterbridge/go-xmpp v0.0.0-20180529212104-cd19799fba91 # github.com/matterbridge/go-xmpp v0.0.0-20180529212104-cd19799fba91
github.com/matterbridge/go-xmpp github.com/matterbridge/go-xmpp
# github.com/matterbridge/gomatrix v0.0.0-20171224233421-78ac6a1a0f5f # github.com/matterbridge/gomatrix v0.0.0-20171224233421-78ac6a1a0f5f
@ -146,19 +146,20 @@ github.com/valyala/bytebufferpool
github.com/valyala/fasttemplate github.com/valyala/fasttemplate
# github.com/zfjagann/golang-ring v0.0.0-20141111230621-17637388c9f6 # github.com/zfjagann/golang-ring v0.0.0-20141111230621-17637388c9f6
github.com/zfjagann/golang-ring github.com/zfjagann/golang-ring
# golang.org/x/crypto v0.0.0-20180228161326-91a49db82a88 # golang.org/x/crypto v0.0.0-20181030102418-4d3f4d9ffa16
golang.org/x/crypto/ssh/terminal golang.org/x/crypto/ssh/terminal
golang.org/x/crypto/acme/autocert golang.org/x/crypto/acme/autocert
golang.org/x/crypto/nacl/secretbox golang.org/x/crypto/nacl/secretbox
golang.org/x/crypto/ssh
golang.org/x/crypto/bcrypt golang.org/x/crypto/bcrypt
golang.org/x/crypto/ssh
golang.org/x/crypto/acme golang.org/x/crypto/acme
golang.org/x/crypto/internal/subtle
golang.org/x/crypto/poly1305 golang.org/x/crypto/poly1305
golang.org/x/crypto/salsa20/salsa golang.org/x/crypto/salsa20/salsa
golang.org/x/crypto/blowfish
golang.org/x/crypto/curve25519 golang.org/x/crypto/curve25519
golang.org/x/crypto/ed25519 golang.org/x/crypto/ed25519
golang.org/x/crypto/internal/chacha20 golang.org/x/crypto/internal/chacha20
golang.org/x/crypto/blowfish
golang.org/x/crypto/ed25519/internal/edwards25519 golang.org/x/crypto/ed25519/internal/edwards25519
# golang.org/x/sys v0.0.0-20171130163741-8b4580aae2a0 # golang.org/x/sys v0.0.0-20171130163741-8b4580aae2a0
golang.org/x/sys/unix golang.org/x/sys/unix