diff options
Diffstat (limited to 'satp/packet.go')
-rw-r--r-- | satp/packet.go | 165 |
1 files changed, 100 insertions, 65 deletions
diff --git a/satp/packet.go b/satp/packet.go index 7e638fa..22cf4ba 100644 --- a/satp/packet.go +++ b/satp/packet.go @@ -31,11 +31,12 @@ package satp import ( - "bytes" "encoding/binary" "errors" + "io" ) +// // 0 1 2 3 // 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ @@ -53,96 +54,130 @@ import ( // | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ | // | | // +- Encrypted Portion Authenticated Portion ---+ +// + +const ( + PACKET_BUFFER_SIZE = 16384 +) + +var ( + ErrTooShort = errors.New("packet is too short") +) type PlainPacket struct { - Type uint16 - Payload []byte + buffer [PACKET_BUFFER_SIZE]byte + header []byte + payload []byte } -func NewPlainPacket(PayloadType uint16) (pp *PlainPacket) { +func NewPlainPacket() (pp *PlainPacket) { pp = &PlainPacket{} - pp.Type = PayloadType + pp.header = pp.buffer[:2:2] + pp.payload = pp.buffer[2:2] return } -func (pp *PlainPacket) MarshalBinary() (data []byte, err error) { - buf := &bytes.Buffer{} - if err = binary.Write(buf, binary.BigEndian, pp.Type); err != nil { - return - } - buf.Write(pp.Payload) // returned error is always nil - return buf.Bytes(), nil +func (pp *PlainPacket) SetPayloadType(payloadType uint16) { + binary.BigEndian.PutUint16(pp.header, payloadType) + return } -func (pp *PlainPacket) UnmarshalBinary(data []byte) (err error) { - buf := bytes.NewReader(data) - if err = binary.Read(buf, binary.BigEndian, &(pp.Type)); err != nil { - return - } - if buf.Len() > 0 { - pp.Payload = make([]byte, buf.Len()) - buf.Read(pp.Payload) +func (pp *PlainPacket) GetPayloadType() (payloadType uint16) { + return binary.BigEndian.Uint16(pp.header) +} + +func (pp *PlainPacket) getPacket() (data []byte) { + return pp.buffer[:len(pp.header)+len(pp.payload)] +} + +func (pp *PlainPacket) ReadFrom(r io.Reader) (int64, error) { + n, err := r.Read(pp.payload[:cap(pp.payload)]) + if err != nil && err != io.EOF { + return 0, err } - return + pp.payload = pp.payload[:n] + return int64(n), nil +} + +func (pp *PlainPacket) WriteTo(w io.Writer) (int64, error) { + n, err := w.Write(pp.payload) + return int64(n), err } type EncryptedPacket struct { - SequenceNumber uint32 - SenderID uint16 - Mux uint16 - Payload []byte - AuthTag []byte + buffer [PACKET_BUFFER_SIZE]byte + header []byte + payload []byte + authTag []byte } -func NewEncryptedPacket(AuthTagLength uint) (ep *EncryptedPacket) { +func NewEncryptedPacket() (ep *EncryptedPacket) { ep = &EncryptedPacket{} - if AuthTagLength > 0 { - ep.AuthTag = make([]byte, AuthTagLength) - } + ep.header = ep.buffer[:8:8] + ep.payload = ep.buffer[8:] + ep.authTag = nil return } -func (ep *EncryptedPacket) MarshalBinary() (data []byte, err error) { - buf := &bytes.Buffer{} - if err = binary.Write(buf, binary.BigEndian, ep.SequenceNumber); err != nil { - return - } - if err = binary.Write(buf, binary.BigEndian, ep.SenderID); err != nil { - return - } - if err = binary.Write(buf, binary.BigEndian, ep.Mux); err != nil { - return +func (ep *EncryptedPacket) SetAuthTagLength(length int) error { + total := len(ep.payload) + len(ep.authTag) + if length <= 0 { + ep.payload = ep.payload[:total] + ep.authTag = nil } - if len(ep.Payload) >= 2 { // Payload must at least have a payload type which is a uint16 -> 2 bytes - buf.Write(ep.Payload) // returned error is always nil - } else { - return nil, errors.New("Unable to marshal packet: payload is empty/too short") + if length > total { + return ErrTooShort } + ep.payload = ep.payload[:total-length] + ep.authTag = ep.buffer[len(ep.header)+len(ep.payload) : total] + return nil +} - if len(ep.AuthTag) > 0 { - buf.Write(ep.AuthTag) // returned error is always nil - } - return buf.Bytes(), nil +func (ep *EncryptedPacket) GetAuthTagLength() int { + return len(ep.authTag) } -func (ep *EncryptedPacket) UnmarshalBinary(data []byte) (err error) { - buf := bytes.NewReader(data) - if err = binary.Read(buf, binary.BigEndian, &(ep.SequenceNumber)); err != nil { - return - } - if err = binary.Read(buf, binary.BigEndian, &(ep.SenderID)); err != nil { - return - } - if err = binary.Read(buf, binary.BigEndian, &(ep.Mux)); err != nil { - return - } +func (ep *EncryptedPacket) SetSequenceNumber(sequenceNumber uint32) { + binary.BigEndian.PutUint32(ep.header[:4], sequenceNumber) + return +} - if (buf.Len() - len(ep.AuthTag)) < 2 { // 2 bytes payload type + length of AuthTag must be left on buffer - return errors.New("Unable to unmarshal packet: too short") - } +func (ep *EncryptedPacket) GetSequenceNumber() (sequenceNumber uint32) { + return binary.BigEndian.Uint32(ep.header[:4]) +} + +func (ep *EncryptedPacket) SetSenderID(senderID uint16) { + binary.BigEndian.PutUint16(ep.header[4:6], senderID) + return +} - ep.Payload = make([]byte, buf.Len()-len(ep.AuthTag)) - buf.Read(ep.Payload) - buf.Read(ep.AuthTag) +func (ep *EncryptedPacket) GetSenderID() (senderID uint16) { + return binary.BigEndian.Uint16(ep.header[4:6]) +} + +func (ep *EncryptedPacket) SetMux(mux uint16) { + binary.BigEndian.PutUint16(ep.header[6:8], mux) return } + +func (ep *EncryptedPacket) GetMux() (mux uint16) { + return binary.BigEndian.Uint16(ep.header[6:8]) +} + +func (ep *EncryptedPacket) ReadFrom(r io.Reader) (int64, error) { + n, err := r.Read(ep.buffer[:]) + if err != nil && err != io.EOF { + return 0, err + } + if n < len(ep.header)+2+len(ep.authTag) { + return 0, ErrTooShort + } + ep.payload = ep.payload[:n-len(ep.header)] + err = ep.SetAuthTagLength(len(ep.authTag)) + return int64(n), err +} + +func (ep *EncryptedPacket) WriteTo(w io.Writer) (int64, error) { + n, err := w.Write(ep.buffer[:len(ep.header)+len(ep.payload)+len(ep.authTag)]) + return int64(n), err +} |