MQTTサーバーを実装しながらGoを学ぶ - その6 const, iota

前回の続きです。handlerのエラーハンドリングからやります。その後、mosquitto_clientからPUBLISHパケットを自作サーバーで受け取れるようにしました。最後にhandlerのリファクタリングで Untyped constant declaration というconstの便利な使い方を知りました。

今回学ぶこと。

  • handlerのエラーハンドリング
  • PUBLISHパケットとDISCONNECTパケット
  • Goのconstとiota

handlerでのエラーハンドリング

handlerでのエラーハンドリングを実装する。前々回調べた通りでError TypeもしくはOpaque Patternを使う。

すでに handler → packet という依存関係ができているので、 packet パッケージにError Typeを作ることにする。 ConnectError というインタフェースを用意する。このインタフェースに Error() を持たせてError Typeとして、さらに Connack を取得するためのメソッドも追加する。

--- a/study/packet/connack.go
+++ b/study/packet/connack.go
@@ -28,29 +28,47 @@ func (c Connack) ToBytes() []byte {
        return result
 }
 
+func newConnack() Connack {
+       fixedHeader := FixedHeader{
+               PacketType:      2,
+               RemainingLength: 2,
+       }
+       variableHeader := ConnackVariableHeader{SessionPresent: false}
+       return Connack{fixedHeader, variableHeader}
+}
+
 func NewConnackForAccepted() Connack {
        result := newConnack()
        result.ReturnCode = 0
        return result
 }
 
-func NewConnackForRefusedByUnacceptableProtocolVersion() Connack {
-       result := newConnack()
-       result.ReturnCode = 1
-       return result
+type ConnectError interface {
+       Connack() Connack
+       Error() string
 }
 
-func NewConnackForRefusedByIdentifierRejected() Connack {
-       result := newConnack()
-       result.ReturnCode = 2
-       return result
+type connectError struct {
+       connack Connack
+       msg     string
 }
 
-func newConnack() Connack {
-       fixedHeader := FixedHeader{
-               PacketType:      2,
-               RemainingLength: 2,
-       }
-       variableHeader := ConnackVariableHeader{SessionPresent: false}
-       return Connack{fixedHeader, variableHeader}
+func (e connectError) Connack() Connack {
+       return e.connack
+}
+
+func (e connectError) Error() string {
+       return e.msg
+}
+
+func RefusedByUnacceptableProtocolVersion(s string) ConnectError {
+       connack := newConnack()
+       connack.ReturnCode = 1
+       return connectError{connack, s}
+}
+
+func RefusedByIdentifierRejected(s string) ConnectError {
+       connack := newConnack()
+       connack.ReturnCode = 2
+       return connectError{connack, s}
 }
--- a/study/packet/connect_payload.go
+++ b/study/packet/connect_payload.go
@@ -5,8 +5,6 @@ import (
        "encoding/binary"
        "io"
        "regexp"
-
-       "github.com/pkg/errors"
 )
 
 type ConnectPayload struct {
@@ -30,10 +28,10 @@ func ToConnectPayload(r *bufio.Reader) (ConnectPayload, error) {
        }
        clientID := string(clientIDBytes)
        if len(clientID) < 1 || len(clientID) > 23 {
-               return ConnectPayload{}, errors.New("ClientID length is invalid")
+               return ConnectPayload{}, RefusedByIdentifierRejected("ClientID length is invalid")
        }
        if !clientIDRegex.MatchString(clientID) {
-               return ConnectPayload{}, errors.New("clientId format shoud be \"0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ\"")
+               return ConnectPayload{}, RefusedByIdentifierRejected("ClientId format shoud be \"0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ\"")
        }
        return ConnectPayload{ClientID: clientID}, nil
 }
--- a/study/packet/connect_variable_header.go
+++ b/study/packet/connect_variable_header.go
@@ -30,11 +30,11 @@ func ToConnectVariableHeader(fixedHeader FixedHeader, r *bufio.Reader) (ConnectV
        protocolName := make([]byte, 6)
        _, err := io.ReadFull(r, protocolName)
        if err != nil || !isValidProtocolName(protocolName) {
-               return ConnectVariableHeader{}, errors.New("protocol name is invalid")
+               return ConnectVariableHeader{}, RefusedByUnacceptableProtocolVersion("protocol name is invalid")
        }
        protocolLevel, err := r.ReadByte()
        if err != nil || protocolLevel != 4 {
-               return ConnectVariableHeader{}, errors.New("protocol level must be 4")
+               return ConnectVariableHeader{}, RefusedByUnacceptableProtocolVersion("protocol level must be 4")
        }
 
        // TODO

handlerパッケージの変更。返された errorpacket.ConnectError だった場合は、 Connack を取得して返すように変更する。

--- a/study/handler/connect_handler.go
+++ b/study/handler/connect_handler.go
@@ -13,14 +13,18 @@ var variableHeaderLength = 10
 func HandleConnect(fixedHeader packet.FixedHeader, r *bufio.Reader) (packet.Connack, error) {
        variableHeader, err := packet.ToConnectVariableHeader(fixedHeader, r)
        if err != nil {
-               // TODO err応じたCONNACKを生成して返す
-               return packet.NewConnackForRefusedByUnacceptableProtocolVersion(), nil
+               if ce, ok := err.(packet.ConnectError); ok {
+                       return ce.Connack(), nil
+               }
+               return packet.Connack{}, err
        }

        payload, err := packet.ToConnectPayload(r)
        if err != nil {
-               // TODO err応じたCONNACKを生成して返す
-               return packet.NewConnackForRefusedByIdentifierRejected(), nil
+               if ce, ok := err.(packet.ConnectError); ok {
+                       return ce.Connack(), nil
+               }
+               return packet.Connack{}, err
        }

        // TODO variableHeaderとpayloadを使って何かしらの処理

PUBLISHパケットとDISCONNECTパケット

やりたかったことを思い出す。最初の回に書いたように、まず実現したいのは以下のフロー。

  1. クライアント → Connect Command → サーバー
  2. クライアント ← Connect Ack ← サーバー
  3. クライアント → Publish Message → サーバー
  4. クライアント → Disconnect Req → サーバー

ここまでで、1と2のCONNECT(Connect Command)とCONNACK(Connect Ack)はできた。

次は、PUBLISH(Publish Message)とDISSCONNECT(Disconnect Req)に取り掛かる。

PUBLISHパケット

PUBLISHパケットの可変ヘッダー

PUBLISHパケットの可変ヘッダーは以下の情報を持つ。

http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718039

  • Topic Name
  • Packet Identifier

クライアントはPUBLISHパケットの Topic Name でトピックを指定してメッセージをサーバー(MQTT Broker)に送る。サーバーは、(まだ未実装だけど)PUBLISHで指定されたトピックをサブスクライブしてるクライアントにメッセージを転送する。

例えば以下のようなイメージ。

https://s3.amazonaws.com/www.appcelerator.com.images/MQTT_1.png

引用元:https://www.appcelerator.com/blog/2018/03/api-builder-and-mqtt-for-iot-part-1/

今はサブスクライブしてるクライアントのことは考えてないので、上の図でいうと左側の "temp" というトピックに対して "75°F" というメッセージをPUBLISHしてる部分を実装する。

Topic Name について以下のような記述がある。Topic Nameのワイルドカードというのは #+ の2文字。サブスクライブ時にワイルドカードを指定することで複数のTopicをサブスクライブすることができる。Publishの可変ヘッダーにはこれらの文字を含んではいけない。

The Topic Name in the PUBLISH Packet MUST NOT contain wildcard characters

Packet Identifier はQoS1かQoS2の場合に使う。今はQoS0固定で考えているので後回し。

可変ヘッダーを実装する。

package packet

import (
    "bufio"
    "fmt"
    "io"
    "strings"
)

type PublishVariableHeader struct {
    TopicName        string
    PacketIdentifier *uint16
}

func ToPublishVariableHeader(fixedHeader FixedHeader, r *bufio.Reader) (PublishVariableHeader, error) {
    if fixedHeader.PacketType != 3 {
        return PublishVariableHeader{}, fmt.Errorf("packet type is invalid. it got is %v", fixedHeader.PacketType)
    }

    _, err := r.ReadByte()
    if err != nil {
        return PublishVariableHeader{}, err
    }
    lengthLSB, err := r.ReadByte()
    if err != nil {
        return PublishVariableHeader{}, err
    }
    if lengthLSB == 0 {
        return PublishVariableHeader{}, fmt.Errorf("length LSB should be > 0")
    }
    topicNameBytes := make([]byte, lengthLSB)
    _, err = io.ReadFull(r, topicNameBytes)
    if err != nil {
        return PublishVariableHeader{}, err
    }
    topicName := string(topicNameBytes)
    if strings.ContainsAny(topicName, "# +") {
        return PublishVariableHeader{}, fmt.Errorf("topic name must not contain wildcard. it got is %v", topicName)
    }

    result := PublishVariableHeader{string(topicNameBytes), nil}
    return result, nil
}

PUBLISHパケットのペイロード

ペイロードは、サブスクライバーに対して送信するメッセージそのもの。 bufio.Reader をそのまま使うことにする。

PUBLISHパケットに対するレスポンス

QoS0の時はレスポンスなし。

DISCONNECTパケット

DISCONNECTパケットの仕様はこちら

http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718090

DISCONNECTパケットには可変ヘッダーもペイロードもない。

serverとhandler実装

前回のserver.goはCONNECTパケットしか想定していなかった。PUBLISHパケットとDISCONNECTパケットも想定した実装にする。

server.goの実装。 Accept() でクライアントからの接続を待つ。接続が来て net.Conn を取得できたら handle() 関数へ渡す。 handle() 関数内でループしてクライアントからのMQTTパケットを受け取り FixedHeader を生成。パケットタイプによるswitchで処理を分岐し、それぞれのパケットに対応したhandlerを呼ぶ。DISCONNECTパケットの場合はコネクションを切るだけなので return nil だけして defer で接続を切る。

package study

import (
    "bufio"
    "fmt"
    "io"
    "net"

    "github.com/bati11/oreno-mqtt/study/handler"
    "github.com/bati11/oreno-mqtt/study/packet"
)

func Run() {
    ln, err := net.Listen("tcp", "localhost:1883")
    if err != nil {
        panic(err)
    }
    fmt.Println("server starts at localhost:1883")

    for {
        conn, err := ln.Accept()
        if err != nil {
            panic(err)
        }

        err = handle(conn)
        if err != nil {
            panic(err)
        }
    }
}

func handle(conn net.Conn) error {
    defer conn.Close()

    for {
        r := bufio.NewReader(conn)
        fixedHeader, err := packet.ToFixedHeader(r)
        if err != nil {
            if err == io.EOF {
                // クライアント側から既に切断してる場合
                return nil
            }
            return err
        }
        fmt.Printf("-----\n%+v\n", fixedHeader)

        switch fixedHeader.PacketType {
        // CONNECT
        case 1:
            connack, err := handler.HandleConnect(fixedHeader, r)
            if err != nil {
                return err
            }
            _, err = conn.Write(connack.ToBytes())
            if err != nil {
                return err
            }
        // PUBLISH
        case 3:
            err := handler.HandlePublish(fixedHeader, r)
            if err != nil {
                return err
            }
        // DISCONNECT
        case 14:
            return nil
        }
    }
}

HandleConnect() は前回から変更なし。

HandlePublish() は以下のように実装。

package handler

import (
    "bufio"
    "fmt"
    "io"

    "github.com/bati11/oreno-mqtt/study/packet"
)

func HandlePublish(fixedHeader packet.FixedHeader, r *bufio.Reader) error {
    fmt.Printf("  HandlePublish\n")
    variableHeader, err := packet.ToPublishVariableHeader(fixedHeader, r)
    if err != nil {
        return err
    }
    fmt.Printf("  %#v\n", variableHeader)

    payloadLength := fixedHeader.RemainingLength - variableHeader.Length
    payload := make([]byte, payloadLength)
    _, err = io.ReadFull(r, payload)
    if err != nil {
        return err
    }
    fmt.Printf("  Payload: %v\n", string(payload))

    // TODO QoS0なのでレスポンスなし
    return nil
}

mosquittoクライアントからメッセージを送ってみる

サーバーを起動する。

$ go run app/main.go 
server starts at localhost:1883

Wiresharkを起動しておく。

mosquittoクライアントからpublishする。

$ mosquitto_pub -t hoge -m "Hello"

結果を見てみる。

https://i.gyazo.com/38af18d974d97c8694564f9eb8f2320a.png

お、できてそう!サーバーの標準出力も確認。

-----
{PacketType:1 Dup:0 QoS1:0 QoS2:0 Retain:0 RemainingLength:28}
HandleConnect
  packet.ConnectVariableHeader{ProtocolName:"MQTT", ProtocolLevel:0x4, ConnectFlags:packet.ConnectFlags{CleanSession:true, WillFlag:true, WillQoS:0x1, WillRetain:false, PasswordFlag:true, UserNameFlag:true}, KeepAlive:0xa}
  packet.ConnectPayload{ClientID:"custom-client-id"}
-----
{PacketType:3 Dup:0 QoS1:0 QoS2:0 Retain:0 RemainingLength:11}
  HandlePublish
  packet.PublishVariableHeader{TopicName:"hoge", PacketIdentifier:(*uint16)(nil), Length:0x6}
  Payload: Hello

ちゃんとPUBLISHパケットを解釈できてる!

これで最初の目標の以下の流れが実現できた。

  1. クライアント → Connect Command → サーバー
  2. クライアント ← Connect Ack ← サーバー
  3. クライアント → Publish Message → サーバー
  4. クライアント → Disconnect Req → サーバー

const

ところで、server.goの fixedHeader.PacketType の値、マジックナンバーでswitchしてるところを分かりやすくしたい。

switch fixedHeader.PacketType {
// CONNECT
case 1:
    connack, err := handler.HandleConnect(fixedHeader, r)
    if err != nil {
        return err
    }
    _, err = conn.Write(connack.ToBytes())
    if err != nil {
        return err
    }
// PUBLISH
case 3:
    err := handler.HandlePublish(fixedHeader, r)
    if err != nil {
        return err
    }
// DISCONNECT
case 14:
    return nil
}

マジックナンバーをconstで置き換える。定数化することでコードが読みやすくなる。また、定数は実行時ではなくコンパイル時にコンパイラが最適化してくれて、パフォーマンスが良くなる場合もある。

fixed_header.goに定数を定義。

+const (
+       CONNECT    byte = 1
+       PUBLISH    byte = 3
+       DISCONNECT byte = 14
+)
+
 type FixedHeader struct {
        PacketType      byte
        Dup             byte

server.goのswitchは以下のようになる。

switch fixedHeader.PacketType {
case packet.CONNECT:
    connack, err := handler.HandleConnect(fixedHeader, r)
    if err != nil {
        return err
    }
    _, err = conn.Write(connack.ToBytes())
    if err != nil {
        return err
    }
case packet.PUBLISH:
    err := handler.HandlePublish(fixedHeader, r)
    if err != nil {
        return err
    }
case packet.DISCONNECT:
    return nil
}

Untyped constant declaration

定数の型を byte ではなく int にするとどうなるか?

const (
       CONNECT    int = 1
       PUBLISH    byte = 3
       DISCONNECT byte = 14
)

server.goのswitch-caseのところで以下のようなコンパイルエラーになる。fixedHeader.PacketType の型が byte なのに、 int 型の定数と比較してるためコンパイルエラー。

server.go:49:3: invalid case packet.CONNECT in switch on fixedHeader.PacketType (mismatched types int and byte)

では、定数の型を 書かない 場合はどうなるか?

const (
    CONNECT         = 1
    PUBLISH    byte = 3
    DISCONNECT byte = 14
)

これだとコンパイルエラーにならない。 しかも、先程と異なり、 int と比較してる箇所もコンパイルエラーにならない!

コンパイル時に適切な精度の型として埋め込んでくれるらしい。先ほどの記事ではこれを「 Untyped constant declaration 」と呼んでいる。

必要がない限り、型指定なしでconstを定義した方が良い。

const (
    CONNECT    = 1
    PUBLISH    = 3
    DISCONNECT = 14
)

iota

constsについて、Effective Goも読んでみると以下のように書いてある。

In Go, enumerated constants are created using the iota enumerator.

enumというとJavaの列挙型を思い出す。けど、それは一旦置いておいてGoでは iota という演算子を使うことで定数の定義を簡単かつ柔軟にできる。

PacketTypeの値の定数定義は iota を使って以下のように書ける。

const (
    _ = iota
    CONNECT
    CONNACK
    PUBLISH
    PUBACK
    PUBREC
    PUBREL
    PUBCOMP
    SUBSCRIBE
    SUBACK
    UNSUBSCRIBE
    UNSUBACK
    PINGREQ
    PINGRESP
    DISCONNECT
)

さらに、型を作ってメソッドを定義すれば、Enum(列挙型)のように使うこともできる。

type PacketType byte

const (
    _ PacketType = iota
    CONNECT
    CONNACK
    PUBLISH
    PUBACK
    PUBREC
    PUBREL
    PUBCOMP
    SUBSCRIBE
    SUBACK
    UNSUBSCRIBE
    UNSUBACK
    PINGREQ
    PINGRESP
    DISCONNECT
)

func (v PacketType) String() string {
    names := [...]string{
        "CONNECT",
        "CONNACK",
        "PUBLISH",
        "PUBACK",
        "PUBREC",
        "PUBREL",
        "PUBCOMP",
        "SUBSCRIBE",
        "SUBACK",
        "UNSUBSCRIBE",
        "UNSUBACK",
        "PINGREQ",
        "PINGRESP",
        "DISCONNECT"}
    if v < CONNECT || v > DISCONNECT {
        return "Unknown"
    }
    return names[v]
}

Enumについては以下の記事が詳しい。

どういう時にEnumが必要なのか?が書いてある。

Why do we need enums?

Grouping and expecting only some related values
Sharing common behavior
Avoids using invalid values
To increase the code readability and the maintainability

また、iotaを使うと定義する順番を間違えると値が変わってしまうため、値に意味があると問題になる場合もある。

いまのところ型や共通のメソッドは不要であるので、Enumはいらない気がする。値を定義する順番については、プロトコルで決まっている値で変更される頻度がほとんどないので気にしないでおく。結果、 iota を使って型指定なしのconstを定義する形にする。

const (
    _ = iota
    CONNECT
    CONNACK
    PUBLISH
    PUBACK
    PUBREC
    PUBREL
    PUBCOMP
    SUBSCRIBE
    SUBACK
    UNSUBSCRIBE
    UNSUBACK
    PINGREQ
    PINGRESP
    DISCONNECT
)

おしまい

ここまでで、MQTTクライアントからメッセージをPublishするところまでができました。次はいよいよ別のクライアントがSubscribeところに着手します。goroutineをどう使ってクライアントを管理するのかを考えていきます。

今回の学び。

MQTTサーバーを実装しながらGoを学ぶ - その5 net, io, bufioパッケージ

前回の続きです。前回までで一応 CONNECTパケットをstructに変換する処理ができました。これでクライアントからのCONNECTパケットをサーバー側で解釈できます。

今回は、CONNECTに対するレスポンスであるCONNACKに取り掛かります。また、実際にサーバーとして動かし、mosquitto_clientと実際にMQTT通信(CONNECT->CONNACK)ができるようにします。その過程で、今まで単純に []byte として扱っていた部分を io.Reader に書き直すことになりました。

目次。

CONNACKパケット

CONNECTパケットを受け取ったサーバーは、クライアントにCONNACKパケットを返す。

CONNACKパケットは、固定ヘッダーと可変ヘッダーから構成される。ペイロードはない。可変ヘッダーは2byte。なので固定ヘッダーの ReminingLength は2で固定値となる。

Session Present と Connect Return Code

CONNACKパケットの可変ヘッダーは、以下の情報を持つ。

  • Session Present
  • Connect Return code

Session PresentはCONNECTパケットで指定されたClientIDとのセッションがサーバーに管理されているかどうかを示す。まだセッションについては考えれていないので常に 0 をセットすることにする。

Connect Return codeは、何通りかある。正常な場合は 0x00 。クライアントから指定されたMQTTのバージョンを受け入れられない場合は 0x01 。クライアントから指定されたClient Identifierを受け入れられない場合は 0x02 。現時点ではこの3パターンを実装する。

以下、Connactパケットの実装。

type ConnackVariableHeader struct {
    SessionPresent bool
    ReturnCode     uint8
}

type Connack struct {
    FixedHeader
    ConnackVariableHeader
}

func NewConnackForAccepted() Connack {
    result := newConnack()
    result.ReturnCode = 0
    return result
}

func NewConnackForRefusedByUnacceptableProtocolVersion() Connack {
    result := newConnack()
    result.ReturnCode = 1
    return result
}

func NewConnackForRefusedByIdentifierRejected() Connack {
    result := newConnack()
    result.ReturnCode = 2
    return result
}

func newConnack() Connack {
    fixedHeader := FixedHeader{
        PacketType:      2,
        RemainingLength: 2,
  }

    // TODO SessionPresentは固定にしておく
    variableHeader := ConnackVariableHeader{SessionPresent: false}

    return Connack{fixedHeader, variableHeader}
}

struct から []byte へ変換

サーバーからクライアントへCONNACKパケットを返す際に、 Connack structをバイト列に変換する必要がある。 ToBytes() メソッドを実装する。

func (c Connack) ToBytes() []byte {
    var result []byte
    result = append(result, c.FixedHeader.ToBytes()...)
    result = append(result, c.ConnackVariableHeader.ToBytes()...)
    return result
}

FixedHeaderConnackVariableHeader にも ToBytes() メソッドを実装する。

ConnackVariableHeader

func (h ConnackVariableHeader) ToBytes() []byte {
    var result []byte
    if h.SessionPresent {
        result = append(result, 1)
    } else {
        result = append(result, 0)
    }
    result = append(result, h.ReturnCode)
    return result
}

FixedHeader

Remining Lengthのencodeのロジックは仕様に書いてある

func (h FixedHeader) ToBytes() []byte {
    var result []byte
    b := h.PacketType << 4
    result = append(result, b)
    remainingLength := encodeRemainingLength(h.RemainingLength)
    result = append(result, remainingLength...)
    return result
}

func encodeRemainingLength(x uint) []byte {
    var encodedByte byte
    var result []byte
    for {
        encodedByte = byte(x % 128)
        x = x / 128
        if x > 0 {
            encodedByte = encodedByte | 128
        }
        result = append(result, encodedByte)
        if x <= 0 {
            break
        }
    }
    return result
}

メソッドのレシーバー

メソッドのレシーバーをstructにするべきかstructのポインタにするべきか、という話がA Tour of GoやEffective Goに書いてある。

個人的には可能な限り不変にしたいのでメソッドのレシーバーはポインタにしたくないけど、不変にしたいならstructのフィールドを全部プライベートにするところまでやらないと意味ないよなぁ。参照はできるけど変更はできない修飾子があったら嬉しい、けどGoにはない。

サーバー実装

サーバーを実装する。複数クライアントを同時に捌くことは後で考える。まずは以下の流れを実現する。

  1. TCPをListen
  2. 固定ヘッダーを取り出す
  3. PacketTypeに応じたhandlerに処理を移譲する
  4. CONNECTパケットを解釈して処理する
  5. CONNACKパケットをクライアントに返す
  6. コネクションを切断する

TCPサーバーは以前にechoサーバーを作ったのでそれを参考に。

bati11blog.hatenablog.com

netパッケージの ListenAccept を使って、 net.Conn を取得する。 net.Conn を使ってクライアントから送られてきたバイト列を取得する。

net package - net - Go Packages

[]byteからbufio.Readerに変更する

サーバーを実装しようと思ったら早速困ったことに。

net.Conn からバイト列を取得して、これまでに実装してきた packet.ToFixedHeader(bs []byte) を呼び出したいが、この []byte の長さはどうしたらいいのだろう。。ToFixedHeader関数でバイト列を引数にとるようにしてるのが設計ミスっぽい。 []byte ではなく bufio.Reader を引数にとるように変更する。

io.Reader 周りは以下が参考になる。

まずはToFixedHeaderの引数を変更する

- func ToFixedHeader(bs []byte) (FixedHeader, error)
+ func ToFixedHeader(r *bufio.Reader) (FixedHeader, error)

1バイトの取得は ReadByte() を使う。 https://golang.org/pkg/bufio/#Reader.ReadByte

- b := bs[0]
+ b, err := r.ReadByte()

ToConnectVariableHeaderの引数も変更する

- func ToConnectVariableHeader(fixedHeader FixedHeader, bs []byte) (ConnectVariableHeader, error)
+ func ToConnectVariableHeader(fixedHeader FixedHeader, r *bufio.Reader) (ConnectVariableHeader, error)

nバイトの取得は io.ReadFull を使う。 https://golang.org/pkg/io/#ReadFull

- if !isValidProtocolName(bs[:6]) {`
+ protocolName := make([]byte, 6)
+ _, err := io.ReadFull(r, protocolName)
+ if err != nil || !isValidProtocolName(protocolName) {

ToConnectPayloadの引数も変更する。

- func ToConnectPayload(bs []byte) (ConnectPayload, error)
+ func ToConnectPayload(r *bufio.Reader) (ConnectPayload, error)

diff全体は大きいので最後に載っけておく。

テストでbufio.Readerをどうするか

[]byte から bufio.Reader に変えたことで、テストコードにおいて引数として []byte{ 0x10, 0x00 } というような値を渡せなくなってしまった。困った。

こういう時は bytes.Buffer を使う。 bytes.Buffer[]byte を保持できて、かつ io.Reader インタフェースを満たしている。

- []byte{0x1B, 0x7F}
+ bufio.NewReader(bytes.NewBuffer([]byte{0x1B, 0x7F}))

CONNECTパケットのhandlerとサーバーを実装する

  • handler/connect_handler.goを実装。エラーハンドリングは次回やる。
  • server.goを実装
  • main関数も実装

connect_handler.go

package handler

import (
    "bufio"
    "fmt"

    "github.com/bati11/oreno-mqtt/mqtt/packet"
)

// CONNECTパケットの可変ヘッダーのバイト数
var variableHeaderLength = 10

func HandleConnect(fixedHeader packet.FixedHeader, r *bufio.Reader) (packet.Connack, error) {
    fmt.Printf("HandleConnect\n")
    variableHeader, err := packet.ToConnectVariableHeader(fixedHeader, r)
    if err != nil {
        // TODO err応じたCONNACKを生成して返す
        return packet.NewConnackForRefusedByUnacceptableProtocolVersion(), nil
    }

    payload, err := packet.ToConnectPayload(r)
    if err != nil {
        // TODO err応じたCONNACKを生成して返す
        return packet.NewConnackForRefusedByIdentifierRejected(), nil
    }

    // TODO variableHeaderとpayloadを使って何かしらの処理
    fmt.Printf("  %#v\n", variableHeader)
    fmt.Printf("  %#v\n", payload)

    return packet.NewConnackForAccepted(), nil
}

server.go

package mqtt

import (
    "bufio"
    "fmt"
    "net"

    "github.com/bati11/oreno-mqtt/mqtt/handler"
    "github.com/bati11/oreno-mqtt/mqtt/packet"
)

func Run() {
    ln, err := net.Listen("tcp", "localhost:1883")
    if err != nil {
        panic(err)
    }
    fmt.Println("server starts at localhost:1883")
    conn, err := ln.Accept()
    if err != nil {
        panic(err)
    }
    defer conn.Close()

    r := bufio.NewReader(conn)
    fixedHeader, err := packet.ToFixedHeader(r)
    if err != nil {
        panic(err)
    }

    switch fixedHeader.PacketType {
    case 1:
        connack, err := handler.HandleConnect(fixedHeader, r)
        if err != nil {
            panic(err)
        }
        _, err = conn.Write(connack.ToBytes())
        if err != nil {
            panic(err)
        }
    }
}

main.go

package main

import "github.com/bati11/oreno-mqtt/mqtt"

func main() {
    mqtt.Run()
}

CONNECTパケットを送信する

モリモリと実装したので、実際にMQTTパケットを送って動作確認をする。

初回でやった $ mosquitto_pub -t hoge -m "Hello" を自作のサーバーに対して送信する。Wiresharkを起動しておく。

自作サーバーを起動する。

$ go run app/main.go
server starts at localhost:1883

CONNECTパケットを送信する。

$ mosquitto_pub -t hoge -m "Hello"

Wiresharkを見てみると...

できたー!!!ちゃんとCONNACKパケットが返ってる!

おしまい

まず、CONNACKパケットを表すstructを実装しました。その次に、 []byte から各種structを生成する処理を bufio.Reader から各種structを生成するように実装を書き換えました。最後に、TCPサーバーを起動して、 mosquitto_pub コマンドでCONNECTパケットを実際に送りました。

次回はエラーハンドリングのTODOのところから。

今回の学び


[]byte から bufio.Reader に書き換えた時の差分

--- a/mqtt/packet/connect_payload.go
+++ b/mqtt/packet/connect_payload.go
@@ -1,7 +1,9 @@
 package packet
 
 import (
+       "bufio"
        "encoding/binary"
+       "io"
        "regexp"
 
        "github.com/pkg/errors"
@@ -13,17 +15,20 @@ type ConnectPayload struct {
 
 var clientIDRegex = regexp.MustCompile("^[a-zA-Z0-9-|]*$")
 
-func ToConnectPayload(bs []byte) (ConnectPayload, error) {
-       if len(bs) < 3 {
-               return ConnectPayload{}, errors.New("payload length is invalid")
+func ToConnectPayload(r *bufio.Reader) (ConnectPayload, error) {
+       lengthBytes := make([]byte, 2)
+       _, err := io.ReadFull(r, lengthBytes)
+       if err != nil {
+               return ConnectPayload{}, err
        }
-       length := binary.BigEndian.Uint16(bs[0:2])
-       var clientID string
-       if len(bs) < 2+int(length) {
-               return ConnectPayload{}, errors.New("specified length is not equals ClientID length")
-       } else {
-               clientID = string(bs[2 : 2+length])
+       length := binary.BigEndian.Uint16(lengthBytes)
+
+       clientIDBytes := make([]byte, length)
+       _, err = io.ReadFull(r, clientIDBytes)
+       if err != nil {
+               return ConnectPayload{}, err
        }
+       clientID := string(clientIDBytes)
        if len(clientID) < 1 || len(clientID) > 23 {
                return ConnectPayload{}, errors.New("ClientID length is invalid")
        }
--- a/mqtt/packet/connect_payload_test.go
+++ b/mqtt/packet/connect_payload_test.go
@@ -1,13 +1,15 @@
 package packet
 
 import (
+       "bufio"
+       "bytes"
        "reflect"
        "testing"
 )
 
 func TestToConnectPayload(t *testing.T) {
        type args struct {
-               bs []byte
+               r *bufio.Reader
        }
        tests := []struct {
                name    string
@@ -17,38 +19,38 @@ func TestToConnectPayload(t *testing.T) {
        }{
                {
                        name:    "ClientIDが1文字",
-                       args:    args{[]byte{0x00, 0x01, 'a'}},
+                       args:    args{bufio.NewReader(bytes.NewBuffer([]byte{0x00, 0x01, 'a'}))},
                        want:    ConnectPayload{ClientID: "a"},
                        wantErr: false,
                },
                {
                        name:    "ペイロードが0byte",
-                       args:    args{[]byte{}},
+                       args:    args{bufio.NewReader(bytes.NewBuffer([]byte{}))},
                        want:    ConnectPayload{},
                        wantErr: true,
                },
                {
                        name:    "ClientIDが23文字を超える",
-                       args:    args{[]byte{0x00, 0x18, '1', '2', '3', '4', '5', '6', '7', '8', '9', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '0', 'a', 'b', 'c', 'd'}},
+                       args:    args{bufio.NewReader(bytes.NewBuffer([]byte{0x00, 0x18, '1', '2', '3', '4', '5', '6', '7', '8', '9', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '0', 'a', 'b', 'c', 'd'}))},
                        want:    ConnectPayload{},
                        wantErr: true,
                },
                {
                        name:    "使えない文字がある",
-                       args:    args{[]byte{0x00, 0x02, '1', '%'}},
+                       args:    args{bufio.NewReader(bytes.NewBuffer([]byte{0x00, 0x02, '1', '%'}))},
                        want:    ConnectPayload{},
                        wantErr: true,
                },
                {
                        name:    "指定された長さよりも実際に取得できたClientIDが短い",
-                       args:    args{[]byte{0x00, 0x03, '1', '2'}},
+                       args:    args{bufio.NewReader(bytes.NewBuffer([]byte{0x00, 0x03, '1', '2'}))},
                        want:    ConnectPayload{},
                        wantErr: true,
                },
        }
        for _, tt := range tests {
                t.Run(tt.name, func(t *testing.T) {
-                       got, err := ToConnectPayload(tt.args.bs)
+                       got, err := ToConnectPayload(tt.args.r)
                        if (err != nil) != tt.wantErr {
                                t.Errorf("ToConnectPayload() error = %v, wantErr %v", err, tt.wantErr)
                                return
--- a/mqtt/packet/connect_variable_header.go
+++ b/mqtt/packet/connect_variable_header.go
@@ -1,6 +1,9 @@
 package packet
 
 import (
+       "bufio"
+       "io"
+
        "github.com/pkg/errors"
 )
 
@@ -20,16 +23,34 @@ type ConnectVariableHeader struct {
        KeepAlive     uint16
 }
 
-func ToConnectVariableHeader(fixedHeader FixedHeader, bs []byte) (ConnectVariableHeader, error) {
+func ToConnectVariableHeader(fixedHeader FixedHeader, r *bufio.Reader) (ConnectVariableHeader, error) {
        if fixedHeader.PacketType != 1 {
                return ConnectVariableHeader{}, errors.New("fixedHeader.PacketType must be 1")
        }
-       if !isValidProtocolName(bs[:6]) {
+       protocolName := make([]byte, 6)
+       _, err := io.ReadFull(r, protocolName)
+       if err != nil || !isValidProtocolName(protocolName) {
                return ConnectVariableHeader{}, errors.New("protocol name is invalid")
        }
-       if bs[6] != 4 {
+       protocolLevel, err := r.ReadByte()
+       if err != nil || protocolLevel != 4 {
                return ConnectVariableHeader{}, errors.New("protocol level must be 4")
        }
+
+       // TODO
+       _, err = r.ReadByte() // connectFlags
+       if err != nil {
+               return ConnectVariableHeader{}, err
+       }
+       _, err = r.ReadByte() // keepAlive MSB
+       if err != nil {
+               return ConnectVariableHeader{}, err
+       }
+       _, err = r.ReadByte() // keepAlive LSB
+       if err != nil {
+               return ConnectVariableHeader{}, err
+       }
+
        return ConnectVariableHeader{
                ProtocolName:  "MQTT",
                ProtocolLevel: 4,
--- a/mqtt/packet/connect_variable_header_test.go
+++ b/mqtt/packet/connect_variable_header_test.go
@@ -1,6 +1,8 @@
 package packet_test
 
 import (
+       "bufio"
+       "bytes"
        "reflect"
        "testing"
 
@@ -10,7 +12,7 @@ import (
 func TestToConnectVariableHeader(t *testing.T) {
        type args struct {
                fixedHeader packet.FixedHeader
-               bs          []byte
+               r           *bufio.Reader
        }
        tests := []struct {
                name    string
@@ -22,12 +24,12 @@ func TestToConnectVariableHeader(t *testing.T) {
                        name: "仕様書のexample",
                        args: args{
                                fixedHeader: packet.FixedHeader{PacketType: 1},
-                               bs: []byte{
+                               r: bufio.NewReader(bytes.NewBuffer([]byte{
                                        0x00, 0x04, 'M', 'Q', 'T', 'T', // Protocol Name
                                        0x04,       // Protocol Level
                                        0xCE,       // Connect Flags
                                        0x00, 0x0A, // Keep Alive
-                               },
+                               })),
                        },
                        want: packet.ConnectVariableHeader{
                                ProtocolName:  "MQTT",
@@ -41,12 +43,12 @@ func TestToConnectVariableHeader(t *testing.T) {
                        name: "固定ヘッダーのPacketTypeが1ではない",
                        args: args{
                                fixedHeader: packet.FixedHeader{PacketType: 2},
-                               bs: []byte{
+                               r: bufio.NewReader(bytes.NewReader([]byte{
                                        0x00, 0x04, 'M', 'Q', 'T', 'T', // Protocol Name
                                        0x04,       // Protocol Level
                                        0xCE,       // Connect Flags
                                        0x00, 0x0A, // Keep Alive
-                               },
+                               })),
                        },
                        want:    packet.ConnectVariableHeader{},
                        wantErr: true,
@@ -55,12 +57,12 @@ func TestToConnectVariableHeader(t *testing.T) {
                        name: "Protocol Nameが不正",
                        args: args{
                                fixedHeader: packet.FixedHeader{PacketType: 1},
-                               bs: []byte{
+                               r: bufio.NewReader(bytes.NewReader([]byte{
                                        0x00, 0x04, 'M', 'Q', 'T', 't', // Protocol Name
                                        0x04,       // Protocol Level
                                        0xCE,       // Connect Flags
                                        0x00, 0x0A, // Keep Alive
-                               },
+                               })),
                        },
                        want:    packet.ConnectVariableHeader{},
                        wantErr: true,
@@ -69,12 +71,12 @@ func TestToConnectVariableHeader(t *testing.T) {
                        name: "Protocol Levelが不正",
                        args: args{
                                fixedHeader: packet.FixedHeader{PacketType: 1},
-                               bs: []byte{
+                               r: bufio.NewReader(bytes.NewReader([]byte{
                                        0x00, 0x04, 'M', 'Q', 'T', 'T', // Protocol Name
                                        0x03,       // Protocol Level
                                        0xCE,       // Connect Flags
                                        0x00, 0x0A, // Keep Alive
-                               },
+                               })),
                        },
                        want:    packet.ConnectVariableHeader{},
                        wantErr: true,
@@ -82,7 +84,7 @@ func TestToConnectVariableHeader(t *testing.T) {
        }
        for _, tt := range tests {
                t.Run(tt.name, func(t *testing.T) {
-                       got, err := packet.ToConnectVariableHeader(tt.args.fixedHeader, tt.args.bs)
+                       got, err := packet.ToConnectVariableHeader(tt.args.fixedHeader, tt.args.r)
                        if (err != nil) != tt.wantErr {
                                t.Errorf("ToConnectVariableHeader() error = %v, wantErr %v", err, tt.wantErr)
                                return
--- a/mqtt/packet/fixed_header.go
+++ b/mqtt/packet/fixed_header.go
@@ -1,6 +1,8 @@
 package packet
 
-import "github.com/pkg/errors"
+import (
+       "bufio"
+)
 
 type FixedHeader struct {
        PacketType      byte
@@ -20,17 +22,20 @@ func (h FixedHeader) ToBytes() []byte {
        return result
 }
 
-func ToFixedHeader(bs []byte) (FixedHeader, error) {
-       if len(bs) <= 1 {
-               return FixedHeader{}, errors.New("len(bs) should be greater than 1")
+func ToFixedHeader(r *bufio.Reader) (FixedHeader, error) {
+       b, err := r.ReadByte()
+       if err != nil {
+               return FixedHeader{}, err
        }
-       b := bs[0]
        packetType := b >> 4
-       dup := refbit(bs[0], 3)
-       qos1 := refbit(bs[0], 2)
-       qos2 := refbit(bs[0], 1)
-       retain := refbit(bs[0], 0)
-       remainingLength := decodeRemainingLength(bs[1:])
+       dup := refbit(b, 3)
+       qos1 := refbit(b, 2)
+       qos2 := refbit(b, 1)
+       retain := refbit(b, 0)
+       remainingLength, err := decodeRemainingLength(r)
+       if err != nil {
+               return FixedHeader{}, err
+       }
        return FixedHeader{
                PacketType:      packetType,
                Dup:             dup,
@@ -46,12 +51,15 @@ func refbit(b byte, n uint) byte {
 }
 
-func decodeRemainingLength(bs []byte) uint {
+func decodeRemainingLength(r *bufio.Reader) (uint, error) {
        multiplier := uint(1)
        var value uint
        i := uint(0)
        for ; i < 8; i++ {
-               b := bs[i]
+               b, err := r.ReadByte()
+               if err != nil {
+                       return 0, err
+               }
                digit := b
                value = value + uint(digit&127)*multiplier
                multiplier = multiplier * 128
@@ -59,7 +67,7 @@ func decodeRemainingLength(bs []byte) uint {
                        break
                }
        }
-       return value
+       return value, nil
 }
 
 func encodeRemainingLength(x uint) []byte {
--- a/mqtt/packet/fixed_header_test.go
+++ b/mqtt/packet/fixed_header_test.go
@@ -1,7 +1,8 @@
 package packet_test
 
 import (
-       "fmt"
+       "bufio"
+       "bytes"
        "reflect"
        "testing"
 
@@ -10,51 +11,57 @@ import (
 
 func TestToFixedHeader(t *testing.T) {
        type args struct {
-               bs []byte
+               r *bufio.Reader
        }
        tests := []struct {
+               name    string
                args    args
                want    packet.FixedHeader
                wantErr bool
        }{
                {
-                       args: args{[]byte{
+                       name: "[0x00,0x00]",
+                       args: args{bufio.NewReader(bytes.NewBuffer([]byte{
                                0x00, // 0000 0 00 0
                                0x00, // 0
-                       }},
+                       }))},
                        want:    packet.FixedHeader{PacketType: 0, Dup: 0, QoS1: 0, QoS2: 0, Retain: 0, RemainingLength: 0},
                        wantErr: false,
                },
                {
-                       args: args{[]byte{
+                       name: "[0x1b,0x7F]",
+                       args: args{bufio.NewReader(bytes.NewBuffer([]byte{
                                0x1B, // 0001 1 01 1
                                0x7F, // 127
-                       }},
+                       }))},
                        want:    packet.FixedHeader{PacketType: 1, Dup: 1, QoS1: 0, QoS2: 1, Retain: 1, RemainingLength: 127},
                        wantErr: false,
                },
                {
-                       args: args{[]byte{
+                       name: "[0x24,0x80,0x01]",
+                       args: args{bufio.NewReader(bytes.NewBuffer([]byte{
                                0x24,       // 0002 0 10 0
                                0x80, 0x01, //128
-                       }},
+                       }))},
                        want:    packet.FixedHeader{PacketType: 2, Dup: 0, QoS1: 1, QoS2: 0, Retain: 0, RemainingLength: 128},
                        wantErr: false,
                },
                {
-                       args:    args{nil},
+                       name:    "[]",
+                       args:    args{bufio.NewReader(bytes.NewBuffer(nil))},
                        want:    packet.FixedHeader{},
                        wantErr: true,
                },
                {
-                       args:    args{[]byte{0x24}},
+                       name:    "[0x24]",
+                       args:    args{bufio.NewReader(bytes.NewBuffer([]byte{0x24}))},
                        want:    packet.FixedHeader{},
                        wantErr: true,
                },
        }
        for _, tt := range tests {
-               t.Run(fmt.Sprintf("%#v", tt.args.bs), func(t *testing.T) {
-                       got, err := packet.ToFixedHeader(tt.args.bs)
+               t.Run(tt.name, func(t *testing.T) {
+                       got, err := packet.ToFixedHeader(tt.args.r)
                        if (err != nil) != tt.wantErr {
                                t.Errorf("ToFixedHeader() error = %v, wantErr %v", err, tt.wantErr)
                                return

MQTTサーバーを実装しながらGoを学ぶ - その4 テストカバレッジ

前回の続きです。

今回は、MQTTのCONNECTパケットのペイロードから。ペイロードをbinaryパッケージを使って実装します。その後、regexパッケージを使って入力チェック処理を書いてテストします。Goではテストカバレッジが簡単に取得できるようだったので、試しにやってみます。

目次。

CONNECTパケットのペイロード

ペイロードには5つのデータが含まれる。

  • Client Identifier
  • Will Topic
  • Will Message
  • UserName
  • Password

この中のClient Identifierは必須。残りの4つは可変ヘッダーのConnect Flagsの値次第で必要かどうかが決まる。Connect Flagsについては後回しにしているので、ペイロードでもClient Identifierだけをとりあえずは実装する。

http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718031

Client Identifier

仕様を読むとClient Identifierはこんな感じ。

  • クライアント毎にユニーク
  • クライアントとサーバー間のセッションを維持するために使う
  • ペイロードの先頭
  • Section 1.5.3で定義されたUTF-8エンコーディングされた文字列
  • さらに条件がある
    • 1〜23byte
    • 使える文字は 0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ
    • 以下のように書いてあるけど、許可しないでおく
      • The Server MAY allow ClientId’s that contain more than 23 encoded bytes. The Server MAY allow ClientId’s that contain characters not included in the list given above.
      • A Server MAY allow a Client to supply a ClientId that has a length of zero bytes
  • 不正なClient Identifierだった場合、サーバーはクライアントにCONNACKパケット(return codeは0x02)を返して、コネクションを切断する

上記の仕様を実装していく。

binaryパッケージ

Client Identifierはペイロードの先頭で、Section 1.5.3で定義されたUTF-8エンコーディングされた文字列なので、ペイロードの先頭の2バイト分がClient Identifierの長さになる。バイト列と数値の変換は binary パッケージが使える。

binary package - encoding/binary - Go Packages

ビッグエンディアンなのでlength := binary.BigEndian.Uint16(payload[0:2]) とすればClient Identifierの長さが取得できる。

package packet

import (
    "encoding/binary"

    "github.com/pkg/errors"
)

type ConnectPayload struct {
    ClientID string
}

func ToConnectPayload(bs []byte) (ConnectPayload, error) {
    if len(bs) < 3 {
        return ConnectPayload{}, errors.New("payload length is invalid")
    }
    length := binary.BigEndian.Uint16(bs[0:2])
    var clientID string
    if len(bs) < 2+int(length) {
        clientID = string(bs[2:])
    } else {
        clientID = string(bs[2 : 2+length])
    }
    if len(clientID) < 1 || len(clientID) > 23 {
        return ConnectPayload{}, errors.New("ClientID length is invalid")
    }
    return ConnectPayload{ClientID: clientID}, nil
}
package packet

import (
    "reflect"
    "testing"
)

func TestToConnectPayload(t *testing.T) {
    type args struct {
        bs []byte
    }
    tests := []struct {
        name    string
        args    args
        want    ConnectPayload
        wantErr bool
    }{
        {
            name:    "ClientIDが1文字",
            args:    args{[]byte{0x00, 0x01, 'a'}},
            want:    ConnectPayload{ClientID: "a"},
            wantErr: false,
        },
        {
            name:    "ペイロードが0byte",
            args:    args{[]byte{}},
            want:    ConnectPayload{},
            wantErr: true,
        },
        {
            name:    "ClientIDが23文字を超える",
            args:    args{[]byte{0x00, 0x18, '1', '2', '3', '4', '5', '6', '7', '8', '9', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '0', 'a', 'b', 'c', 'd'}},
            want:    ConnectPayload{},
            wantErr: true,
        },
    }
    for _, tt := range tests {
        t.Run(tt.name, func(t *testing.T) {
            got, err := ToConnectPayload(tt.args.bs)
            if (err != nil) != tt.wantErr {
                t.Errorf("ToConnectPayload() error = %v, wantErr %v", err, tt.wantErr)
                return
            }
            if !reflect.DeepEqual(got, tt.want) {
                t.Errorf("ToConnectPayload() = %v, want %v", got, tt.want)
            }
        })
    }
}

regexpパッケージ

次は文字の種類をチェック。正規表現を使う。

regexp package - regexp - Go Packages

                        want:    ConnectPayload{},
                        wantErr: true,
                },
+               {
+                       name:    "使えない文字がある",
+                       args:    args{[]byte{0x00, 0x02, '1', '%'}},
+                       want:    ConnectPayload{},
+                       wantErr: true,
+               },
        }
        for _, tt := range tests {
                t.Run(tt.name, func(t *testing.T) {
--- a/study/packet/connect_payload.go
+++ b/study/packet/connect_payload.go
@@ -2,6 +2,7 @@ package packet
 
 import (
        "encoding/binary"
+       "regexp"
 
        "github.com/pkg/errors"
 )
@@ -10,6 +11,8 @@ type ConnectPayload struct {
        ClientID string
 }
 
+var clientIDRegex = regexp.MustCompile("^[a-zA-Z0-9-|]*$")
+
 func ToConnectPayload(bs []byte) (ConnectPayload, error) {
        if len(bs) < 3 {
                return ConnectPayload{}, errors.New("payload length is invalid")
@@ -24,5 +27,8 @@ func ToConnectPayload(bs []byte) (ConnectPayload, error) {
        if len(clientID) < 1 || len(clientID) > 23 {
                return ConnectPayload{}, errors.New("ClientID length is invalid")
        }
+       if !clientIDRegex.MatchString(clientID) {
+               return ConnectPayload{}, errors.New("clientId format shoud be \"0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ\"")
+       }
        return ConnectPayload{ClientID: clientID}, nil
 }

これでOK!

$ go test ./packet/
ok      github.com/bati11/oreno-mqtt/study/packet       1.283s

テストのカバレッジ

go test でテストを実行してきたが、オプションをつけるとカバレッジを取得できるらしい。しかも -cover オプションをつけるだけ。

The cover story - The Go Programming Language

試してみる。

$ go test -cover ./packet/
ok      github.com/bati11/oreno-mqtt/study/packet       1.270s  coverage: 97.6% of statements

97.6%。さっきの文字種類のテストコードを削ってみる。

                        want:    ConnectPayload{},
                        wantErr: true,
                },
-               {
-                       name:    "使えない文字がある",
-                       args:    args{[]byte{0x00, 0x02, '1', '%'}},
-                       want:    ConnectPayload{},
-                       wantErr: true,
-               },
        }
        for _, tt := range tests {
                t.Run(tt.name, func(t *testing.T) {

実行。

$ go test -cover ./packet/
ok      github.com/bati11/oreno-mqtt/study/packet       0.655s  coverage: 95.2% of statements

95.2%に下がった。

もっと詳細に調べてみる。 -coverprofile=cover.out という指定をすると、cover.outというファイルができて、そのファイルを go tool cover コマンドでプロファイリングできる。

$ go test -coverprofile=cover.out ./packet/
ok      github.com/bati11/oreno-mqtt/study/packet       0.019s  coverage: 95.2% of statements
$ go tool cover -func=cover.out
github.com/bati11/oreno-mqtt/study/packet/connect_payload.go:16:                ToConnectPayload        83.3%
github.com/bati11/oreno-mqtt/study/packet/connect_variable_header.go:23:        ToConnectVariableHeader 100.0%
github.com/bati11/oreno-mqtt/study/packet/connect_variable_header.go:41:        isValidProtocolName     100.0%
github.com/bati11/oreno-mqtt/study/packet/fixed_header.go:14:                   ToFixedHeader           100.0%
github.com/bati11/oreno-mqtt/study/packet/fixed_header.go:35:                   refbit                  100.0%
github.com/bati11/oreno-mqtt/study/packet/fixed_header.go:40:                   decodeRemainingLength   100.0%
total:                                                                          (statements)            95.2%

ふむふむ、connect_payload.goの ToConnectPayloadカバレッジが低めだということが分かる。

さらにさらに、 -html=cover.out と指定するとブラウザで確認することもできる。

$ go tool cover -html=cover.out

https://i.gyazo.com/5e2b68026888cf2537364dc5f0c2a41d.png

す、すごい...。

他にも -covermode というオプションがあり、実行された回数まで含めてプロファイルもできそう。

カバレッジ見てたら、指定された長さに対して実際に取得できるClient Identifierの長さが足りない場合のテストが足りないことに気がついたので追加。さっき消したテストも戻す。

--- a/study/packet/connect_payload.go
+++ b/study/packet/connect_payload.go
@@ -20,7 +20,7 @@ func ToConnectPayload(bs []byte) (ConnectPayload, error) {
        length := binary.BigEndian.Uint16(bs[0:2])
        var clientID string
        if len(bs) < 2+int(length) {
-               clientID = string(bs[2:])
+               return ConnectPayload{}, errors.New("specified length is not equals ClientID length")
        } else {
                clientID = string(bs[2 : 2+length])
        }
--- a/study/packet/connect_payload_test.go
+++ b/study/packet/connect_payload_test.go
@@ -33,6 +33,18 @@ func TestToConnectPayload(t *testing.T) {
                        want:    ConnectPayload{},
                        wantErr: true,
                },
+               {
+                       name:    "使えない文字がある",
+                       args:    args{[]byte{0x00, 0x02, '1', '%'}},
+                       want:    ConnectPayload{},
+                       wantErr: true,
+               },
+               {
+                       name:    "指定された長さよりも実際に取得できたClientIDが短い",
+                       args:    args{[]byte{0x00, 0x03, '1', '2'}},
+                       want:    ConnectPayload{},
+                       wantErr: true,
+               },
        }
        for _, tt := range tests {
                t.Run(tt.name, func(t *testing.T) {
go test -coverprofile=cover.out ./packet/
ok      github.com/bati11/oreno-mqtt/study/packet       0.022s  coverage: 100.0% of statements
 go tool cover -func=cover.out
github.com/bati11/oreno-mqtt/study/packet/connect_payload.go:16:                ToConnectPayload        100.0%
github.com/bati11/oreno-mqtt/study/packet/connect_variable_header.go:23:        ToConnectVariableHeader 100.0%
github.com/bati11/oreno-mqtt/study/packet/connect_variable_header.go:41:        isValidProtocolName     100.0%
github.com/bati11/oreno-mqtt/study/packet/fixed_header.go:14:                   ToFixedHeader           100.0%
github.com/bati11/oreno-mqtt/study/packet/fixed_header.go:35:                   refbit                  100.0%
github.com/bati11/oreno-mqtt/study/packet/fixed_header.go:40:                   decodeRemainingLength   100.0%
total:                                                                          (statements)            100.0%

$ go test --help を読むと以下のように書いてある。

-cover
    Enable coverage analysis.
    Note that because coverage works by annotating the source
    code before compilation, compilation and test failures with
    coverage enabled may report line numbers that don't correspond
    to the original sources.

-cover をつけない状態でテストして、PASSしてからカバレッジを取得する方が良さそうだ。

おしまい

MQTTのCONNECTパケットのペイロードを実装して、テストのカバレッジを取得しました。次回はサーバーとして起動するところまでいきたい。

今回の学び