use Stdin.

This commit is contained in:
mattn 2011-11-04 22:40:10 +09:00
parent b6a67e2320
commit 0a0f20b95e
2 changed files with 57 additions and 40 deletions

View File

@ -1,9 +1,9 @@
package main package main
import ( import (
"bufio"
"fmt" "fmt"
"flag" "flag"
"github.com/kless/go-readin/readin"
"github.com/mattn/go-xmpp" "github.com/mattn/go-xmpp"
"github.com/mattn/go-iconv" "github.com/mattn/go-iconv"
"log" "log"
@ -61,11 +61,12 @@ func main() {
} }
}() }()
for { for {
line, err := readin.RepeatPrompt("") in := bufio.NewReader(os.Stdin)
line, err := in.ReadString('\n')
if err != nil { if err != nil {
fmt.Fprintln(os.Stderr, err.String())
continue continue
} }
line = strings.TrimRight(line, "\n")
tokens := strings.SplitN(line, " ", 2) tokens := strings.SplitN(line, " ", 2)
if len(tokens) == 2 { if len(tokens) == 2 {

90
xmpp.go
View File

@ -17,6 +17,7 @@ import (
"bytes" "bytes"
"crypto/tls" "crypto/tls"
"encoding/base64" "encoding/base64"
"errors"
"fmt" "fmt"
"http" "http"
"io" "io"
@ -24,8 +25,8 @@ import (
"net" "net"
"os" "os"
"strings" "strings"
"xml"
"url" "url"
"xml"
) )
const ( const (
@ -47,7 +48,7 @@ type Client struct {
// NewClient creates a new connection to a host given as "hostname" or "hostname:port". // NewClient creates a new connection to a host given as "hostname" or "hostname:port".
// If host is not specified, the DNS SRV should be used to find the host from the domainpart of the JID. // If host is not specified, the DNS SRV should be used to find the host from the domainpart of the JID.
// Default the port to 5222. // Default the port to 5222.
func NewClient(host, user, passwd string) (*Client, os.Error) { func NewClient(host, user, passwd string) (*Client, error) {
addr := host addr := host
if strings.TrimSpace(host) == "" { if strings.TrimSpace(host) == "" {
@ -87,7 +88,7 @@ func NewClient(host, user, passwd string) (*Client, os.Error) {
} }
if resp.StatusCode != 200 { if resp.StatusCode != 200 {
f := strings.SplitN(resp.Status, " ", 2) f := strings.SplitN(resp.Status, " ", 2)
return nil, os.NewError(f[1]) return nil, errors.New(f[1])
} }
} }
@ -112,18 +113,18 @@ func NewClient(host, user, passwd string) (*Client, os.Error) {
return client, nil return client, nil
} }
func (c *Client) Close() os.Error { func (c *Client) Close() error {
return c.tls.Close() return c.tls.Close()
} }
func (c *Client) init(user, passwd string) os.Error { func (c *Client) init(user, passwd string) error {
// For debugging: the following causes the plaintext of the connection to be duplicated to stdout. // For debugging: the following causes the plaintext of the connection to be duplicated to stdout.
// c.p = xml.NewParser(tee{c.tls, os.Stdout}); // c.p = xml.NewParser(tee{c.tls, os.Stdout});
c.p = xml.NewParser(c.tls) c.p = xml.NewParser(c.tls)
a := strings.SplitN(user, "@", 2) a := strings.SplitN(user, "@", 2)
if len(a) != 2 { if len(a) != 2 {
return os.NewError("xmpp: invalid username (want user@domain): " + user) return errors.New("xmpp: invalid username (want user@domain): " + user)
} }
user = a[0] user = a[0]
domain := a[1] domain := a[1]
@ -140,7 +141,7 @@ func (c *Client) init(user, passwd string) os.Error {
return err return err
} }
if se.Name.Space != nsStream || se.Name.Local != "stream" { if se.Name.Space != nsStream || se.Name.Local != "stream" {
return os.NewError("xmpp: expected <stream> but got <" + se.Name.Local + "> in " + se.Name.Space) return errors.New("xmpp: expected <stream> but got <" + se.Name.Local + "> in " + se.Name.Space)
} }
// Now we're in the stream and can use Unmarshal. // Now we're in the stream and can use Unmarshal.
@ -148,7 +149,7 @@ func (c *Client) init(user, passwd string) os.Error {
// See section 4.6 in RFC 3920. // See section 4.6 in RFC 3920.
var f streamFeatures var f streamFeatures
if err = c.p.Unmarshal(&f, nil); err != nil { if err = c.p.Unmarshal(&f, nil); err != nil {
return os.NewError("unmarshal <features>: " + err.String()) return errors.New("unmarshal <features>: " + err.Error())
} }
havePlain := false havePlain := false
for _, m := range f.Mechanisms.Mechanism { for _, m := range f.Mechanisms.Mechanism {
@ -158,7 +159,7 @@ func (c *Client) init(user, passwd string) os.Error {
} }
} }
if !havePlain { if !havePlain {
return os.NewError(fmt.Sprintf("PLAIN authentication is not an option: %v", f.Mechanisms.Mechanism)) return errors.New(fmt.Sprintf("PLAIN authentication is not an option: %v", f.Mechanisms.Mechanism))
} }
// Plain authentication: send base64-encoded \x00 user \x00 password. // Plain authentication: send base64-encoded \x00 user \x00 password.
@ -175,9 +176,9 @@ func (c *Client) init(user, passwd string) os.Error {
case *saslFailure: case *saslFailure:
// v.Any is type of sub-element in failure, // v.Any is type of sub-element in failure,
// which gives a description of what failed. // which gives a description of what failed.
return os.NewError("auth failure: " + v.Any.Local) return errors.New("auth failure: " + v.Any.Local)
default: default:
return os.NewError("expected <success> or <failure>, got <" + name.Local + "> in " + name.Space) return errors.New("expected <success> or <failure>, got <" + name.Local + "> in " + name.Space)
} }
// Now that we're authenticated, we're supposed to start the stream over again. // Now that we're authenticated, we're supposed to start the stream over again.
@ -192,7 +193,7 @@ func (c *Client) init(user, passwd string) os.Error {
return err return err
} }
if se.Name.Space != nsStream || se.Name.Local != "stream" { if se.Name.Space != nsStream || se.Name.Local != "stream" {
return os.NewError("expected <stream>, got <" + se.Name.Local + "> in " + se.Name.Space) return errors.New("expected <stream>, got <" + se.Name.Local + "> in " + se.Name.Space)
} }
if err = c.p.Unmarshal(&f, nil); err != nil { if err = c.p.Unmarshal(&f, nil); err != nil {
// TODO: often stream stop. // TODO: often stream stop.
@ -203,10 +204,10 @@ func (c *Client) init(user, passwd string) os.Error {
fmt.Fprintf(c.tls, "<iq type='set' id='x'><bind xmlns='%s'/></iq>\n", nsBind) fmt.Fprintf(c.tls, "<iq type='set' id='x'><bind xmlns='%s'/></iq>\n", nsBind)
var iq clientIQ var iq clientIQ
if err = c.p.Unmarshal(&iq, nil); err != nil { if err = c.p.Unmarshal(&iq, nil); err != nil {
return os.NewError("unmarshal <iq>: " + err.String()) return errors.New("unmarshal <iq>: " + err.Error())
} }
if &iq.Bind == nil { if &iq.Bind == nil {
return os.NewError("<iq> result missing <bind>") return errors.New("<iq> result missing <bind>")
} }
c.jid = iq.Bind.Jid // our local id c.jid = iq.Bind.Jid // our local id
@ -222,7 +223,7 @@ type Chat struct {
} }
// Recv wait next token of chat. // Recv wait next token of chat.
func (c *Client) Recv() (chat Chat, err os.Error) { func (c *Client) Recv() (chat Chat, err error) {
for { for {
_, val, err := next(c.p) _, val, err := next(c.p)
if err != nil { if err != nil {
@ -243,7 +244,6 @@ func (c *Client) Send(chat Chat) {
xmlEscape(chat.Text)) xmlEscape(chat.Text))
} }
// RFC 3920 C.1 Streams name space // RFC 3920 C.1 Streams name space
type streamFeatures struct { type streamFeatures struct {
@ -366,7 +366,7 @@ type clientError struct {
} }
// Scan XML token stream to find next StartElement. // Scan XML token stream to find next StartElement.
func nextStart(p *xml.Parser) (xml.StartElement, os.Error) { func nextStart(p *xml.Parser) (xml.StartElement, error) {
for { for {
t, err := p.Token() t, err := p.Token()
if err != nil { if err != nil {
@ -383,7 +383,7 @@ func nextStart(p *xml.Parser) (xml.StartElement, os.Error) {
// Scan XML token stream for next element and save into val. // Scan XML token stream for next element and save into val.
// If val == nil, allocate new element based on proto map. // If val == nil, allocate new element based on proto map.
// Either way, return val. // Either way, return val.
func next(p *xml.Parser) (xml.Name, interface{}, os.Error) { func next(p *xml.Parser) (xml.Name, interface{}, error) {
// Read start element to find out what type we want. // Read start element to find out what type we want.
se, err := nextStart(p) se, err := nextStart(p)
if err != nil { if err != nil {
@ -392,25 +392,41 @@ func next(p *xml.Parser) (xml.Name, interface{}, os.Error) {
// Put it in an interface and allocate one. // Put it in an interface and allocate one.
var nv interface{} var nv interface{}
switch (se.Name.Space+" "+se.Name.Local) { switch se.Name.Space + " " + se.Name.Local {
case nsStream + " features": nv = &streamFeatures{} case nsStream + " features":
case nsStream + " error": nv = &streamError{} nv = &streamFeatures{}
case nsTLS + " starttls": nv = &tlsStartTLS{} case nsStream + " error":
case nsTLS + " proceed": nv = &tlsProceed{} nv = &streamError{}
case nsTLS + " failure": nv = &tlsFailure{} case nsTLS + " starttls":
case nsSASL + " mechanisms": nv = &saslMechanisms{} nv = &tlsStartTLS{}
case nsSASL + " challenge": nv = "" case nsTLS + " proceed":
case nsSASL + " response": nv = "" nv = &tlsProceed{}
case nsSASL + " abort": nv = &saslAbort{} case nsTLS + " failure":
case nsSASL + " success": nv = &saslSuccess{} nv = &tlsFailure{}
case nsSASL + " failure": nv = &saslFailure{} case nsSASL + " mechanisms":
case nsBind + " bind": nv = &bindBind{} nv = &saslMechanisms{}
case nsClient + " message": nv = &clientMessage{} case nsSASL + " challenge":
case nsClient + " presence": nv = &clientPresence{} nv = ""
case nsClient + " iq": nv = &clientIQ{} case nsSASL + " response":
case nsClient + " error": nv = &clientError{} nv = ""
case nsSASL + " abort":
nv = &saslAbort{}
case nsSASL + " success":
nv = &saslSuccess{}
case nsSASL + " failure":
nv = &saslFailure{}
case nsBind + " bind":
nv = &bindBind{}
case nsClient + " message":
nv = &clientMessage{}
case nsClient + " presence":
nv = &clientPresence{}
case nsClient + " iq":
nv = &clientIQ{}
case nsClient + " error":
nv = &clientError{}
default: default:
return xml.Name{}, nil, os.NewError("unexpected XMPP message " + return xml.Name{}, nil, errors.New("unexpected XMPP message " +
se.Name.Space + " <" + se.Name.Local + "/>") se.Name.Space + " <" + se.Name.Local + "/>")
} }
@ -447,7 +463,7 @@ type tee struct {
w io.Writer w io.Writer
} }
func (t tee) Read(p []byte) (n int, err os.Error) { func (t tee) Read(p []byte) (n int, err error) {
n, err = t.r.Read(p) n, err = t.r.Read(p)
if n > 0 { if n > 0 {
t.w.Write(p[0:n]) t.w.Write(p[0:n])