diff --git a/go.mod b/go.mod index d8ce01b..dde017b 100644 --- a/go.mod +++ b/go.mod @@ -2,4 +2,7 @@ module github.com/xmppo/go-xmpp go 1.21.5 -require golang.org/x/crypto v0.18.0 +require ( + golang.org/x/crypto v0.18.0 + golang.org/x/net v0.10.0 +) diff --git a/go.sum b/go.sum index 4dda6b7..959da46 100644 --- a/go.sum +++ b/go.sum @@ -1,2 +1,4 @@ golang.org/x/crypto v0.18.0 h1:PGVlW0xEltQnzFZ55hkuX5+KLyrMYhHld1YHO4AKcdc= golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1mg= +golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M= +golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= diff --git a/xmpp.go b/xmpp.go index c009c7f..4dd161a 100644 --- a/xmpp.go +++ b/xmpp.go @@ -41,6 +41,7 @@ import ( "time" "golang.org/x/crypto/pbkdf2" + "golang.org/x/net/proxy" ) const ( @@ -103,12 +104,12 @@ func connect(host, user, passwd string, timeout time.Duration) (net.Conn, error) addr += ":5222" } - proxy := os.Getenv("HTTP_PROXY") - if proxy == "" { - proxy = os.Getenv("http_proxy") + http_proxy := os.Getenv("HTTP_PROXY") + if http_proxy == "" { + http_proxy = os.Getenv("http_proxy") } // test for no proxy, takes a comma separated list with substrings to match - if proxy != "" { + if http_proxy != "" { noproxy := os.Getenv("NO_PROXY") if noproxy == "" { noproxy = os.Getenv("no_proxy") @@ -117,25 +118,38 @@ func connect(host, user, passwd string, timeout time.Duration) (net.Conn, error) nplist := strings.Split(noproxy, ",") for _, s := range nplist { if containsIgnoreCase(addr, s) { - proxy = "" + http_proxy = "" break } } } } - if proxy != "" { - url, err := url.Parse(proxy) + socks5Target, socks5 := strings.CutPrefix(http_proxy, "socks5://") + if http_proxy != "" && !socks5 { + url, err := url.Parse(http_proxy) if err == nil { addr = url.Host } } - - c, err := net.DialTimeout("tcp", addr, timeout) - if err != nil { - return nil, err + var c net.Conn + var err error + if socks5 { + dialer, err := proxy.SOCKS5("tcp", socks5Target, nil, nil) + if err != nil { + return nil, err + } + c, err = dialer.Dial("tcp", addr) + if err != nil { + return nil, err + } + } else { + c, err = net.DialTimeout("tcp", addr, timeout) + if err != nil { + return nil, err + } } - if proxy != "" { + if http_proxy != "" && !socks5 { fmt.Fprintf(c, "CONNECT %s HTTP/1.1\r\n", host) fmt.Fprintf(c, "Host: %s\r\n", host) fmt.Fprintf(c, "\r\n")