diff --git a/internal/conn/conn.go b/conn.go similarity index 56% rename from internal/conn/conn.go rename to conn.go index 5abb303..b76b769 100644 --- a/internal/conn/conn.go +++ b/conn.go @@ -1,9 +1,8 @@ -package conn +package q3rcon import ( "fmt" "net" - "time" "github.com/charmbracelet/log" ) @@ -12,23 +11,23 @@ type UDPConn struct { conn *net.UDPConn } -func New(host string, port int) (UDPConn, error) { - udpAddr, err := net.ResolveUDPAddr("udp4", fmt.Sprintf("%s:%d", host, port)) +func newUDPConn(host string, port int) (*UDPConn, error) { + udpAddr, err := net.ResolveUDPAddr("udp4", net.JoinHostPort(host, fmt.Sprintf("%d", port))) if err != nil { - return UDPConn{}, err + return nil, err } conn, err := net.DialUDP("udp4", nil, udpAddr) if err != nil { - return UDPConn{}, err + return nil, err } log.Infof("Outgoing address %s", conn.RemoteAddr()) - return UDPConn{ + return &UDPConn{ conn: conn, }, nil } -func (c UDPConn) Write(buf []byte) (int, error) { +func (c *UDPConn) Write(buf []byte) (int, error) { n, err := c.conn.Write(buf) if err != nil { return 0, err @@ -37,8 +36,7 @@ func (c UDPConn) Write(buf []byte) (int, error) { return n, nil } -func (c UDPConn) ReadUntil(timeout time.Time, buf []byte) (int, error) { - c.conn.SetReadDeadline(timeout) +func (c *UDPConn) Read(buf []byte) (int, error) { rlen, _, err := c.conn.ReadFromUDP(buf) if err != nil { return 0, err @@ -46,7 +44,7 @@ func (c UDPConn) ReadUntil(timeout time.Time, buf []byte) (int, error) { return rlen, nil } -func (c UDPConn) Close() error { +func (c *UDPConn) Close() error { err := c.conn.Close() if err != nil { return err diff --git a/internal/packet/request.go b/internal/packet/request.go deleted file mode 100644 index 06e3833..0000000 --- a/internal/packet/request.go +++ /dev/null @@ -1,36 +0,0 @@ -package packet - -import ( - "bytes" - "fmt" - - "github.com/charmbracelet/log" -) - -const bufSz = 512 - -type Request struct { - magic []byte - password string - buf *bytes.Buffer -} - -func NewRequest(password string) Request { - return Request{ - magic: []byte{'\xff', '\xff', '\xff', '\xff'}, - password: password, - buf: bytes.NewBuffer(make([]byte, bufSz)), - } -} - -func (r Request) Header() []byte { - return append(r.magic, "rcon"...) -} - -func (r Request) Encode(cmd string) []byte { - r.buf.Reset() - r.buf.Write(r.Header()) - r.buf.WriteString(fmt.Sprintf(" %s %s", r.password, cmd)) - log.Debugf("Encoded request: %s", r.buf.String()) - return r.buf.Bytes() -} diff --git a/internal/packet/response.go b/internal/packet/response.go deleted file mode 100644 index 979574a..0000000 --- a/internal/packet/response.go +++ /dev/null @@ -1,13 +0,0 @@ -package packet - -type Response struct { - magic []byte -} - -func NewResponse() Response { - return Response{magic: []byte{'\xff', '\xff', '\xff', '\xff'}} -} - -func (r Response) Header() []byte { - return append(r.magic, "print\n"...) -} diff --git a/q3rcon.go b/q3rcon.go index 8e2255a..ca287c6 100644 --- a/q3rcon.go +++ b/q3rcon.go @@ -1,24 +1,31 @@ package q3rcon import ( - "bytes" "errors" + "fmt" + "io" "net" "strings" "time" "github.com/charmbracelet/log" - - "github.com/onyx-and-iris/q3rcon/internal/conn" - "github.com/onyx-and-iris/q3rcon/internal/packet" ) const respBufSiz = 2048 +type encoder interface { + Encode(cmd string) ([]byte, error) +} + +type decoder interface { + IsValid(buf []byte) bool + Decode(buf []byte) string +} + type Rcon struct { - conn conn.UDPConn - request packet.Request - response packet.Response + conn io.ReadWriteCloser + request encoder + response decoder loginTimeout time.Duration defaultTimeout time.Duration @@ -30,15 +37,15 @@ func New(host string, port int, password string, options ...Option) (*Rcon, erro return nil, errors.New("no password provided") } - conn, err := conn.New(host, port) + conn, err := newUDPConn(host, port) if err != nil { - return nil, err + return nil, fmt.Errorf("error creating UDP connection: %w", err) } r := &Rcon{ conn: conn, - request: packet.NewRequest(password), - response: packet.NewResponse(), + request: newRequest(password), + response: newResponse(), loginTimeout: 5 * time.Second, defaultTimeout: 20 * time.Millisecond, @@ -50,7 +57,7 @@ func New(host string, port int, password string, options ...Option) (*Rcon, erro } if err = r.login(); err != nil { - return nil, err + return nil, fmt.Errorf("error logging in: %w", err) } return r, nil @@ -65,7 +72,7 @@ func (r Rcon) login() error { default: resp, err := r.Send("login") if err != nil { - return err + return fmt.Errorf("error sending login command: %w", err) } if resp == "" { continue @@ -94,9 +101,14 @@ func (r Rcon) Send(cmdWithArgs string) (string, error) { go r.listen(timeout, respChan, errChan) - _, err := r.conn.Write(r.request.Encode(cmdWithArgs)) + encodedCmd, err := r.request.Encode(cmdWithArgs) if err != nil { - return "", err + return "", fmt.Errorf("error encoding command: %w", err) + } + + _, err = r.conn.Write(encodedCmd) + if err != nil { + return "", fmt.Errorf("error writing command to connection: %w", err) } select { @@ -118,7 +130,17 @@ func (r Rcon) listen(timeout time.Duration, respChan chan<- string, errChan chan respChan <- sb.String() return default: - rlen, err := r.conn.ReadUntil(time.Now().Add(timeout), respBuf) + c, ok := r.conn.(*UDPConn) + if !ok { + errChan <- errors.New("connection is not a UDPConn") + return + } + err := c.conn.SetReadDeadline(time.Now().Add(timeout)) + if err != nil { + errChan <- fmt.Errorf("error setting read deadline: %w", err) + return + } + rlen, err := r.conn.Read(respBuf) if err != nil { e, ok := err.(net.Error) if ok { @@ -131,10 +153,8 @@ func (r Rcon) listen(timeout time.Duration, respChan chan<- string, errChan chan } } - if rlen > len(r.response.Header()) { - if bytes.HasPrefix(respBuf, r.response.Header()) { - sb.Write(respBuf[len(r.response.Header()):rlen]) - } + if r.response.IsValid(respBuf[:rlen]) { + sb.WriteString(r.response.Decode(respBuf[:rlen])) } } } diff --git a/request.go b/request.go new file mode 100644 index 0000000..331e148 --- /dev/null +++ b/request.go @@ -0,0 +1,35 @@ +package q3rcon + +import ( + "bytes" + "errors" + "fmt" +) + +const ( + bufSz = 1024 + requestHeader = "\xff\xff\xff\xffrcon" +) + +type request struct { + password string + buf *bytes.Buffer +} + +func newRequest(password string) request { + return request{ + password: password, + buf: bytes.NewBuffer(make([]byte, 0, bufSz)), + } +} + +func (r request) Encode(cmd string) ([]byte, error) { + if cmd == "" { + return nil, errors.New("command cannot be empty") + } + + r.buf.Reset() + r.buf.WriteString(requestHeader) + r.buf.WriteString(fmt.Sprintf(" %s %s", r.password, cmd)) + return r.buf.Bytes(), nil +} diff --git a/response.go b/response.go new file mode 100644 index 0000000..1c4f452 --- /dev/null +++ b/response.go @@ -0,0 +1,21 @@ +package q3rcon + +import "bytes" + +const ( + responseHeader = "\xff\xff\xff\xffprint\n" +) + +type response struct{} + +func newResponse() response { + return response{} +} + +func (r response) IsValid(buf []byte) bool { + return len(buf) > len(responseHeader) && bytes.HasPrefix(buf, []byte(responseHeader)) +} + +func (r response) Decode(buf []byte) string { + return string(buf[len(responseHeader):]) +}