SQLアンチパターン・ジェイウォークのクエリをシェルでやる

SQLアンチパターンという本があります。 その本の1章がジェイウォーク(信号無視)。ジェイウォークで紹介されているようなデータがtsvファイルとして手元にある場合に、SQLではなくシェルでなんとかするお話です。

SQLアンチパターン

SQLアンチパターン

試しに使ったMySQLのバージョンは5.7です。

ジェイウォーク

製品テーブルとアカウントテーブルがあり、製品ごとに複数人の担当者(アカウント)がいる、とする

  CREATE TABLE `accounts` (
    `account_id` bigint(20) unsigned NOT NULL AUTO_INCREMENT,
    `account_name` varchar(20) COLLATE utf8mb4_bin DEFAULT NULL,
    PRIMARY KEY (`account_id`),
    UNIQUE KEY `account_id` (`account_id`)
  )
  
  CREATE TABLE `products` (
    `product_id` bigint(20) unsigned NOT NULL AUTO_INCREMENT,
    `product_name` varchar(1000) COLLATE utf8mb4_bin DEFAULT NULL,
    `account_id` varchar(100) COLLATE utf8mb4_bin DEFAULT NULL,
    PRIMARY KEY (`product_id`),
    UNIQUE KEY `product_id` (`product_id`)
  )
  mysql> SELECT * FROM products;
  +------------+---------------------+------------+
  | product_id | product_name        | account_id |
  +------------+---------------------+------------+
  |          1 | Visual TurboBuilder | 12,34      |
  |          2 | hoge fuga           | 555,666    |
  +------------+---------------------+------------+
  
  mysql> SELECT * FROM accounts;
  +------------+--------------+
  | account_id | account_name |
  +------------+--------------+
  |         12 | taro         |
  |         34 | hanako       |
  |        555 | goro         |
  +------------+--------------+

製品テーブルのaccount_idカラムに、複数のアカウントIDをカンマ区切りの文字列として保持するアンチパターン。

tsvファイル

同じようなtsvファイルがあるとする。

$ cat products.tsv
product_id  product_name    account_id
1  Visual TurboBuilder 12,34
2  hoge fuga   555,666,777

$ cat accounts.tsv
account_id  account_name
12 taro
34 hanako
555    goro

SQL vs シェル

特定のアカウントに関連する製品の検索

account_id=12に関連するproductsを検索する。

SQLアンチパターンで紹介されているSQLは以下。正規表現を使った WHERE 句の指定がなかなか強烈...。

SELECT * FROM products;
+------------+---------------------+-------------+
| product_id | product_name        | account_id  |
+------------+---------------------+-------------+
|          1 | Visual TurboBuilder | 12,34       |
|          2 | hoge fuga           | 555,666,777 |
+------------+---------------------+-------------+
SELECT * FROM products WHERE account_id REGEXP '[[:<:]]12[[:>:]]';
+------------+---------------------+------------+
| product_id | product_name        | account_id |
+------------+---------------------+------------+
|          1 | Visual TurboBuilder | 12,34      |
+------------+---------------------+------------+

対して、シェルでやってみる。 NR で行番号を取得してヘッダー行かそれ以外の行かで分岐、ヘッダー行以外の場合は $3 (3カラム目の値)と正規表現を使って行を絞り込む。

$ cat products.tsv
product_id  product_name    account_id
1   Visual TurboBuilder 12,34
2   hoge fuga   555,666,777
$ cat products.tsv | awk -F '\t' '{
>     if (NR == 1) { print }         #ヘッダー行を出力
>     else if ($3 ~ /12/) { print }  #ヘッダー行以外は account_id=12 で絞り込み
> }'
product_id  product_name    account_id
1   Visual TurboBuilder 12,34

ヘッダー行の出力をなんとかして if をなくしたいなぁ、いい方法ないだろうか?

特定の製品に関連するアカウントの検索

product_id=1に関連するアカウント情報の一覧を取得する。

SQLアンチパターンで紹介されているSQLは以下。

SELECT * FROM products;
+------------+---------------------+-------------+
| product_id | product_name        | account_id  |
+------------+---------------------+-------------+
|          1 | Visual TurboBuilder | 12,34       |
|          2 | hoge fuga           | 555,666,777 |
+------------+---------------------+-------------+

SELECT * FROM accounts;
+------------+--------------+
| account_id | account_name |
+------------+--------------+
|         12 | taro         |
|         34 | hanako       |
|        555 | goro         |
+------------+--------------+
SELECT *
  FROM products AS p
 INNER JOIN accounts AS a
         ON p.account_id REGEXP CONCAT('[[:<:]]', a.account_id, '[[:>:]]')
 WHERE p.product_id = 1;
+------------+---------------------+------------+------------+--------------+
| product_id | product_name        | account_id | account_id | account_name |
+------------+---------------------+------------+------------+--------------+
|          1 | Visual TurboBuilder | 12,34      |         12 | taro         |
|          1 | Visual TurboBuilder | 12,34      |         34 | hanako       |
+------------+---------------------+------------+------------+--------------+

シェルでやってみる。

$ cat products.tsv
product_id  product_name    account_id
1   Visual TurboBuilder 12,34
2   hoge fuga   555,666,777

$ cat accounts.tsv
account_id  account_name
12  taro
34  hanako
555 goro

ヘッダー行の取り扱いはさっきと同じ。 split を使ってカンマ区切りの文字列を配列にセット、配列に対するループ内で print することで、account_id毎の行に展開する。結果は、tmp_product.tsvというファイルに書き込んでおく。

$ cat products.tsv | awk -F '\t' '{
>     if (NR == 1) { print }                 #ヘッダー行を出力
>     else if ($1 == 1) {                    #ヘッダー行以外は product_id=1 で絞り込み
>         split($3, arr, ",");               #splitでカンマで分割
>         for (i in arr) {
>             print $1 "\t" $2 "\t" arr[i]   #複数行に展開
>         }
>     }
> }' > tmp_product.tsv

$ cat tmp_product.tsv
product_id  product_name    account_id
1   Visual TurboBuilder 12
1   Visual TurboBuilder 34

作成したtmp_product.tsvとaccounts.tsvとを、account_id列でjoinする。タブ文字区切りにしたいので -t オプションを使うが "\t" という指定をしてもタブ文字として扱ってくれないので、bashの ${string}の記法を使って指定する(この記法は何か呼び名はあるのかな?)。

$ join --header -t $'\t' -1 3 -2 1 tmp_product.tsv accounts.tsv 
account_id  product_id  product_name    account_name
12  1   Visual TurboBuilder taro
34  1   Visual TurboBuilder hanako

できた!

ワンライナーで書きたい場合は、bashだとプロセス置換を使ってjoinコマンドで読む。

$ join --header -t $'\t' -1 3 -2 1 <(cat products.tsv | awk -F '\t' '{ if (NR == 1) { print } else if ($1 == 1) { split($3, arr, ","); for (i in arr) { print $1 "\t" $2 "\t" arr[i] }}}') accounts.tsv 
account_id  product_id  product_name    account_name
12  1   Visual TurboBuilder taro
34  1   Visual TurboBuilder hanako

集約クエリ

product毎に関連するaccountの数を取得する。

SQLアンチパターンで紹介されてるクエリは以下。文字列長からカンマ以外の文字列長を引くことでカンマの数を取得して +1 する。これもなかなか強烈。

SELECT product_id, LENGTH(account_id) - LENGTH(REPLACE(account_id, ',', '')) + 1 AS contracts_per_product
  FROM products;
+------------+-----------------------+
| product_id | contracts_per_product |
+------------+-----------------------+
|          1 |                     2 |
|          2 |                     3 |
+------------+-----------------------+

これはシェルなら簡単。awkのsplitが配列の要素の数を返してくれるので、それを出力すればOK!

$ cat products.tsv | awk -F '\t' '
> BEGIN { print "product_id" "\t" "contracts_per_product"} #BEGINでヘッダー行を出力する
> NR > 1 {
>     n = split($3, arr, ",");
>     print $1 "\t" n
> }'
product_id  contracts_per_product
1   2
2   3

おしまい

シェルはおもしろいですね!

MQTTサーバーを実装しながらGoを学ぶ - その12 Contextを使ったgoroutineの停止

前回、別goroutineで発生したエラーハンドリングをしました。具体的にはサブスクライバへの書き込みでエラーが発生した場合にサブスクリプションの削除処理するようにしました。今回は、 context.Context を使ってgoroutineを停止することで、クライアントが接続を切ったタイミングでサブスクリプションの削除をするようにしてみます。

目次。

今のgoroutine

現状の実装だと、クライアントからSUBSCRIBEパケットが送信されてくると、いくつかのgoroutineが生成され以下のような親子関係になる。

https://i.gyazo.com/6bf617a5406e91b099915ce86e0aff09.png

  • main
    • サーバーのメインgoroutine
    • Run 関数内の conn, err := ln.Accept() でブロックしてる
  • Broker
    • "main" goroutineによって1つだけ生成される
    • パブリッシャからのメッセージをサブスクリプションへ配送するgoroutine
    • Broker 関数内の for ... select ... で無限ループしてる
  • handle
    • "main" goroutineによってクライアントからのTCP接続がある度に生成される
    • MQTTパケットを受け取り処理する
    • handle 関数内から mqttReader.ReadPacketType() の呼び出し、最終的には net.Conn に対する Read でブロックしてる
  • handleSub
    • "handle" goroutineによってクライアントからのSUBSCRIBEパケットにより生成される
    • Broker からchannel経由でPUBLISHを受け取り、 net.Conn に書き込む
    • channelの読み取りでブロックしてる
  • handle内の無名関数
    • "handle" goroutineによってクライアントからのSUBSCRIBEパケットにより生成される
    • handleSub 関数で発生した error をerror channel経由で読み取り Broker に伝える
    • channelの読み取りでブロックしてる

net.Conn に対する Read をしているのは"handle" goroutineなので、サブスクライバが切断したことは"handle" goroutineで処理できるはず。

goroutineリーク

サブスクライバが切断、つまりmosquitto_subをCtrl-Cで終了したとき、handleSubのgoroutineが減らない。

2つmosquitto_subを実行してる状態でpprofを使いgoroutieの状態を見てみる。pprofは導入済みなので http://localhost:6060/debug/pprof/goroutine?debug=1 にアクセスすれば良い。

結果の一部は以下。

2 @ 0x102f20b 0x102a6a9 0x1029d56 0x108e22a 0x108e33d 0x108f0e6 0x118011f 0x11926f8 0x11330ef 0x1133989 0x13239b7 0x1327a6b 0x1328457 0x105ca31
#   0x1029d55   internal/poll.runtime_pollWait+0x65                     /Users/bati11/.goenv/versions/1.11.4/src/runtime/netpoll.go:173
#   0x108e229   internal/poll.(*pollDesc).wait+0x99                     /Users/bati11/.goenv/versions/1.11.4/src/internal/poll/fd_poll_runtime.go:85
#   0x108e33c   internal/poll.(*pollDesc).waitRead+0x3c                     /Users/bati11/.goenv/versions/1.11.4/src/internal/poll/fd_poll_runtime.go:90
#   0x108f0e5   internal/poll.(*FD).Read+0x1d5                          /Users/bati11/.goenv/versions/1.11.4/src/internal/poll/fd_unix.go:169
#   0x118011e   net.(*netFD).Read+0x4e                              /Users/bati11/.goenv/versions/1.11.4/src/net/fd_unix.go:202
#   0x11926f7   net.(*conn).Read+0x67                               /Users/bati11/.goenv/versions/1.11.4/src/net/net.go:177
#   0x11330ee   bufio.(*Reader).fill+0x10e                          /Users/bati11/.goenv/versions/1.11.4/src/bufio/bufio.go:100
#   0x1133988   bufio.(*Reader).ReadByte+0x38                           /Users/bati11/.goenv/versions/1.11.4/src/bufio/bufio.go:242
#   0x13239b6   github.com/bati11/oreno-mqtt/mqtt/packet.(*MQTTReader).ReadPacketType+0x76  /Users/bati11/dev/src/github.com/bati11/oreno-mqtt/mqtt/packet/mqtt_reader.go:20
#   0x1327a6a   github.com/bati11/oreno-mqtt/mqtt.handle+0x13a                  /Users/bati11/dev/src/github.com/bati11/oreno-mqtt/mqtt/server.go:54
#   0x1328456   github.com/bati11/oreno-mqtt/mqtt.Run.func1+0x56                /Users/bati11/dev/src/github.com/bati11/oreno-mqtt/mqtt/server.go:38

2 @ 0x102f20b 0x102f2b3 0x100758e 0x10072bb 0x1328501 0x105ca31
#   0x1328500   github.com/bati11/oreno-mqtt/mqtt.handle.func1+0x40 /Users/bati11/dev/src/github.com/bati11/oreno-mqtt/mqtt/server.go:91

2 @ 0x102f20b 0x102f2b3 0x100758e 0x10072bb 0x132865c 0x105ca31
#   0x132865b   github.com/bati11/oreno-mqtt/mqtt.handleSub.func1+0x6b  /Users/bati11/dev/src/github.com/bati11/oreno-mqtt/mqtt/server.go:118

3つのブロックがあるが、最初のブロックは packet.(*MQTTReader).ReadPacketType+0x76 とあり mqttReader.ReadPacketType() の呼び出しのことなので "handle" goroutineである。先頭の 2 という数値から "handle" goroutineが2つ存在してることになる。

2つめのブロックは、"handle" goroutineの数は 1 に減っているが、"handle 関数内の無名関数" goroutine、3つめのブロックは "handleSub" goroutineである。2つサブスクライバがいるので、goroutineの数もそれぞれ2つである。

1つCtrl-Cで終了してから再度状態を確認。"handle内の無名関数" goroutineと"handleSub" goroutineが減ってない。

goroutine profile: total 10
2 @ 0x102f20b 0x102f2b3 0x100758e 0x10072bb 0x1328501 0x105ca31
#   0x1328500   github.com/bati11/oreno-mqtt/mqtt.handle.func1+0x40 /Users/bati11/dev/src/github.com/bati11/oreno-mqtt/mqtt/server.go:91

2 @ 0x102f20b 0x102f2b3 0x100758e 0x10072bb 0x132865c 0x105ca31
#   0x132865b   github.com/bati11/oreno-mqtt/mqtt.handleSub.func1+0x6b  /Users/bati11/dev/src/github.com/bati11/oreno-mqtt/mqtt/server.go:118

1 @ 0x102f20b 0x102a6a9 0x1029d56 0x108e22a 0x108e33d 0x108f0e6 0x118011f 0x11926f8 0x11330ef 0x1133989 0x13239b7 0x1327a6b 0x1328457 0x105ca31
#   0x1029d55   internal/poll.runtime_pollWait+0x65                     /Users/bati11/.goenv/versions/1.11.4/src/runtime/netpoll.go:173
#   0x108e229   internal/poll.(*pollDesc).wait+0x99                     /Users/bati11/.goenv/versions/1.11.4/src/internal/poll/fd_poll_runtime.go:85
#   0x108e33c   internal/poll.(*pollDesc).waitRead+0x3c                     /Users/bati11/.goenv/versions/1.11.4/src/internal/poll/fd_poll_runtime.go:90
#   0x108f0e5   internal/poll.(*FD).Read+0x1d5                          /Users/bati11/.goenv/versions/1.11.4/src/internal/poll/fd_unix.go:169
#   0x118011e   net.(*netFD).Read+0x4e                              /Users/bati11/.goenv/versions/1.11.4/src/net/fd_unix.go:202
#   0x11926f7   net.(*conn).Read+0x67                               /Users/bati11/.goenv/versions/1.11.4/src/net/net.go:177
#   0x11330ee   bufio.(*Reader).fill+0x10e                          /Users/bati11/.goenv/versions/1.11.4/src/bufio/bufio.go:100
#   0x1133988   bufio.(*Reader).ReadByte+0x38                           /Users/bati11/.goenv/versions/1.11.4/src/bufio/bufio.go:242
#   0x13239b6   github.com/bati11/oreno-mqtt/mqtt/packet.(*MQTTReader).ReadPacketType+0x76  /Users/bati11/dev/src/github.com/bati11/oreno-mqtt/mqtt/packet/mqtt_reader.go:20
#   0x1327a6a   github.com/bati11/oreno-mqtt/mqtt.handle+0x13a                  /Users/bati11/dev/src/github.com/bati11/oreno-mqtt/mqtt/server.go:54
#   0x1328456   github.com/bati11/oreno-mqtt/mqtt.Run.func1+0x56                /Users/bati11/dev/src/github.com/bati11/oreno-mqtt/mqtt/server.go:38

goroutineが残り続けてしまうことを、goroutineリークというらしい。

"handleSub" goroutineの親goroutineである "handle" goroutineが終了するときに、子goroutineを終了させたい。

おなじみの「Go言語による並行処理」に以下のように書いてある。

Go言語による並行処理

Go言語による並行処理

  • 作者: Katherine Cox-Buday,山口能迪
  • 出版社/メーカー: オライリージャパン
  • 発売日: 2018/10/26
  • メディア: 単行本(ソフトカバー)
  • この商品を含むブログを見る

4.3 ゴルーチンリークを避ける

もしあるゴルーチンがゴルーチンの生成の責任を持っているのであれば、そのゴルーチンを停止できるようにする責任もあります。

書籍では done チャネルを使った方式と、それがGo1.7からは context パッケージとして標準化されたことが書いてある。

contextパッケージ

context.Contextを使う。

context.Contextを使うことで、親goroutineから子goroutineを停止することができる。

diff --git a/mqtt/server.go b/mqtt/server.go
index 5f53e6f..a050e0a 100644
--- a/mqtt/server.go
+++ b/mqtt/server.go
@@ -2,6 +2,7 @@ package mqtt
 
 import (
        "bufio"
+       "context"
        "fmt"
        "io"
        "net"
@@ -48,6 +49,9 @@ func handle(conn net.Conn, publishToBroker chan<- *packet.Publish, subscriptionT
 
        var clientID string
 
+       ctx, cancel := context.WithCancel(context.Background())
+       defer cancel()
+
        for {
                r := bufio.NewReader(conn)
                mqttReader := packet.NewMQTTReader(r)
@@ -85,16 +89,21 @@ func handle(conn net.Conn, publishToBroker chan<- *packet.Publish, subscriptionT
                        if err != nil {
                                return err
                        }
-                       subscription, errCh := handleSub(clientID, conn)
+                       subscription, errCh := handleSub(ctx, clientID, conn)
                        subscriptionToBroker <- subscription
-                       go func() {
-                               err, ok := <-errCh
-                               if !ok {
-                                       return
+                       go func(ctx context.Context) {
+                               var result *DoneSubscriptionResult
+                               select {
+                               case <-ctx.Done():
+                                       result = NewDoneSubscriptionResult(subscription.clientID, nil)
+                               case err, ok := <-errCh:
+                                       if !ok {
+                                               return
+                                       }
+                                       result = NewDoneSubscriptionResult(subscription.clientID, err)
                                }
-                               done := NewDoneSubscriptionResult(subscription.clientID, err)
-                               doneSubscriptions <- done
-                       }()
+                               doneSubscriptions <- result
+                       }(ctx)
                case packet.PINGREQ:
                        pingresp, err := handler.HandlePingreq(mqttReader)
                        if err != nil {
@@ -110,16 +119,24 @@ func handle(conn net.Conn, publishToBroker chan<- *packet.Publish, subscriptionT
        }
 }
 
-func handleSub(clientID string, conn net.Conn) (*Subscription, <-chan error) {
+func handleSub(ctx context.Context, clientID string, conn net.Conn) (*Subscription, <-chan error) {
        errCh := make(chan error)
        subscription, pubFromBroker := NewSubscription(clientID)
        go func() {
                defer close(errCh)
-               for publishMessage := range pubFromBroker {
-                       bs := publishMessage.ToBytes()
-                       _, err := conn.Write(bs)
-                       if err != nil {
-                               errCh <- err
+               for {
+                       select {
+                       case <-ctx.Done():
+                               return
+                       case publishedMessage, ok := <-pubFromBroker:
+                               if !ok {
+                                       return
+                               }
+                               bs := publishedMessage.ToBytes()
+                               _, err := conn.Write(bs)
+                               if err != nil {
+                                       errCh <- err
+                               }
                        }
                }
        }()

試してみる。

2つsub

goroutine profile: total 12
2 @ 0x102f20b 0x102a6a9 0x1029d56 0x108e22a 0x108e33d 0x108f0e6 0x118011f 0x11926f8 0x11330ef 0x1133989 0x13239b7 0x1327abc 0x1328527 0x105ca31
#   0x1029d55   internal/poll.runtime_pollWait+0x65                     /Users/bati11/.goenv/versions/1.11.4/src/runtime/netpoll.go:173
#   0x108e229   internal/poll.(*pollDesc).wait+0x99                     /Users/bati11/.goenv/versions/1.11.4/src/internal/poll/fd_poll_runtime.go:85
#   0x108e33c   internal/poll.(*pollDesc).waitRead+0x3c                     /Users/bati11/.goenv/versions/1.11.4/src/internal/poll/fd_poll_runtime.go:90
#   0x108f0e5   internal/poll.(*FD).Read+0x1d5                          /Users/bati11/.goenv/versions/1.11.4/src/internal/poll/fd_unix.go:169
#   0x118011e   net.(*netFD).Read+0x4e                              /Users/bati11/.goenv/versions/1.11.4/src/net/fd_unix.go:202
#   0x11926f7   net.(*conn).Read+0x67                               /Users/bati11/.goenv/versions/1.11.4/src/net/net.go:177
#   0x11330ee   bufio.(*Reader).fill+0x10e                          /Users/bati11/.goenv/versions/1.11.4/src/bufio/bufio.go:100
#   0x1133988   bufio.(*Reader).ReadByte+0x38                           /Users/bati11/.goenv/versions/1.11.4/src/bufio/bufio.go:242
#   0x13239b6   github.com/bati11/oreno-mqtt/mqtt/packet.(*MQTTReader).ReadPacketType+0x76  /Users/bati11/dev/src/github.com/bati11/oreno-mqtt/mqtt/packet/mqtt_reader.go:20
#   0x1327abb   github.com/bati11/oreno-mqtt/mqtt.handle+0x18b                  /Users/bati11/dev/src/github.com/bati11/oreno-mqtt/mqtt/server.go:58
#   0x1328526   github.com/bati11/oreno-mqtt/mqtt.Run.func1+0x56                /Users/bati11/dev/src/github.com/bati11/oreno-mqtt/mqtt/server.go:39

2 @ 0x102f20b 0x103ef16 0x132865f 0x105ca31
#   0x132865e   github.com/bati11/oreno-mqtt/mqtt.handle.func1+0xce /Users/bati11/dev/src/github.com/bati11/oreno-mqtt/mqtt/server.go:95

2 @ 0x102f20b 0x103ef16 0x1328890 0x105ca31
#   0x132888f   github.com/bati11/oreno-mqtt/mqtt.handleSub.func1+0xff  /Users/bati11/dev/src/github.com/bati11/oreno-mqtt/mqtt/server.go:127

1つをCtrl-Cで閉じる

goroutine profile: total 8
1 @ 0x102f20b 0x102a6a9 0x1029d56 0x108e22a 0x108e33d 0x108f0e6 0x118011f 0x11926f8 0x11330ef 0x1133989 0x13239b7 0x1327abc 0x1328527 0x105ca31
#   0x1029d55   internal/poll.runtime_pollWait+0x65                     /Users/bati11/.goenv/versions/1.11.4/src/runtime/netpoll.go:173
#   0x108e229   internal/poll.(*pollDesc).wait+0x99                     /Users/bati11/.goenv/versions/1.11.4/src/internal/poll/fd_poll_runtime.go:85
#   0x108e33c   internal/poll.(*pollDesc).waitRead+0x3c                     /Users/bati11/.goenv/versions/1.11.4/src/internal/poll/fd_poll_runtime.go:90
#   0x108f0e5   internal/poll.(*FD).Read+0x1d5                          /Users/bati11/.goenv/versions/1.11.4/src/internal/poll/fd_unix.go:169
#   0x118011e   net.(*netFD).Read+0x4e                              /Users/bati11/.goenv/versions/1.11.4/src/net/fd_unix.go:202
#   0x11926f7   net.(*conn).Read+0x67                               /Users/bati11/.goenv/versions/1.11.4/src/net/net.go:177
#   0x11330ee   bufio.(*Reader).fill+0x10e                          /Users/bati11/.goenv/versions/1.11.4/src/bufio/bufio.go:100
#   0x1133988   bufio.(*Reader).ReadByte+0x38                           /Users/bati11/.goenv/versions/1.11.4/src/bufio/bufio.go:242
#   0x13239b6   github.com/bati11/oreno-mqtt/mqtt/packet.(*MQTTReader).ReadPacketType+0x76  /Users/bati11/dev/src/github.com/bati11/oreno-mqtt/mqtt/packet/mqtt_reader.go:20
#   0x1327abb   github.com/bati11/oreno-mqtt/mqtt.handle+0x18b                  /Users/bati11/dev/src/github.com/bati11/oreno-mqtt/mqtt/server.go:58
#   0x1328526   github.com/bati11/oreno-mqtt/mqtt.Run.func1+0x56                /Users/bati11/dev/src/github.com/bati11/oreno-mqtt/mqtt/server.go:39

1 @ 0x102f20b 0x103ef16 0x132865f 0x105ca31
#   0x132865e   github.com/bati11/oreno-mqtt/mqtt.handle.func1+0xce /Users/bati11/dev/src/github.com/bati11/oreno-mqtt/mqtt/server.go:95

1 @ 0x102f20b 0x103ef16 0x1328890 0x105ca31
#   0x132888f   github.com/bati11/oreno-mqtt/mqtt.handleSub.func1+0xff  /Users/bati11/dev/src/github.com/bati11/oreno-mqtt/mqtt/server.go:127

ちゃんと"handle内の無名関数" goroutineと"handleSub" goroutineが減ってる。

おしまい

context.Context を使って、子goroutineを停止させることができました。 ctx をどんどん渡していけば孫goroutineやひ孫goroutineなども同じように停止させることができます。Goでは複数の並行処理を協調させる仕組みが色々あって良いですね!

コードはこちら。

github.com

今回の学び。

MQTTサーバーを実装しながらGoを学ぶ - その11 goroutineのエラーハンドリング, map, goroutineセーフ

前回の続き。別のgoroutineで発生したエラーをerror channelを使ってハンドリングしてみたいと思います。ハンドリングの処理でmapを使ったのですが、goroutineセーフにするため sync.Map を使ってみました。

目次。

handleSub でエラー

handleSub 関数でサブスクライバに対してPUBLISHを送っている。

func handleSub(conn net.Conn, fromBroker <-chan *packet.Publish) {
    for publishMessage := range fromBroker {
        bs := publishMessage.ToBytes()
        _, err := conn.Write(bs)
        if err != nil {
            // FIXME
            fmt.Println(err)
        }
    }
}

conn.Write でエラーだった場合のことを考える。例えば、mosquitto_subでサブスクライブした後、すぐにCtrl-Cで切断しておく。その後にmosquitto_pubでパブリッシュすると conn.Writeerr を返し、サーバーが以下を出力する。

write tcp 127.0.0.1:1883->127.0.0.1:59452: use of closed network connection

Broker がサブスクライバが切断したことに気がつかず、サブスクリプションを削除せずにPUBLISHされたメッセージを配送し続けてしまう。 handleSub goroutineで発生したエラーを Broker goroutineに伝えて、 Broker が不要なサブスクリプションを削除する方法を考える。

別goroutineを受け取るerror channel

PUBLISHメッセージを配送する Broker -> handlerSub という方向のchannelとは別に、 errorhandleSub -> Broker という方向に流すchannelを用意すれば良い。

現在の Broker は以下。

// broker.go

func Broker(fromPub <-chan *packet.Publish, subscriptions <-chan *Subscription, doneSubscriptions <-chan *DoneSubscription) {
    sMap := newSubscriptionMap()
    for {
        select {
        case sub := <-subscriptions:
            sMap.put(sub.clientID, sub)
        case message := <-fromPub:
            sMap.apply(func(sub *Subscription) {
                sub.pubToSub <- message
            })
        }
    }
}

以下のように selectdoneSubscriptions channelも読むようにする。

func Broker(fromPub <-chan *packet.Publish, subscriptions <-chan *Subscription, doneSubscriptions <-chan error) {
    sMap := newSubscriptionMap()
    for {
        select {
        case sub := <-subscriptions:
            sMap.put(sub.clientID, sub)
        case message := <-fromPub:
            sMap.apply(func(sub *Subscription) {
                sub.pubToSub <- message
            })
+       case err := <-doneSubscriptions:
+           fmt.Println(err)
+       }
    }
}

doneSubscriptions channelを使って handleSub goroutine から Broker goroutineに error を流す。

server.goでは3つのchannelを生成することになる。

func Run() {
    ln, err := net.Listen("tcp", "localhost:1883")
    if err != nil {
        panic(err)
    }
    defer ln.Close()

    pub := make(chan *packet.Publish)
    defer close(pub)
    subscriptions := make(chan *Subscription)
    defer close(subscriptions)
+   doneSubscriptions := make(chan *DoneSubscriptionResult)
+   defer close(doneOfSubscription)

    go Broker(pub, subscriptions, doneSubscriptions)

    for {
        conn, err := ln.Accept()
        if err != nil {
            panic(err)
        }

        go func() {
-          err = handle(conn, pub, subscriptions)
+           err = handle(conn, pub, subscriptions, doneSubscriptions)
            if err != nil {
                panic(err)
            }
        }()
    }
}

doneSubscriptions channelが 「Run 関数のgoroutine」 -> 「handle関数を実行するgoroutine」 -> 「handleSub 関数のgoroutine」と渡されていくことになる。

つまり書き込み可能なchannelが孫goroutineまで渡されることになる。channelの所有者がはっきりしなくなってきた...。

書き込み可能なchannelを関数に渡すのではなく、関数内でchannelとgoroutineを生成し、読み取り専用channelを返すようにする。

具体的には、 handleSub 関数を以下のようにする。

// server.go

func handleSub(conn net.Conn, fromBroker <-chan *packet.Publish) <-chan error {
    errCh := make(chan error)
    go func() {
        defer close(errCh)
        for publishMessage := range fromBroker {
            bs := publishMessage.ToBytes()
            _, err := conn.Write(bs)
            if err != nil {
                errCh <- err
            }
        }
    }()
    return errCh
}

errCh channelのライフサイクルが handleSubscription 関数内に閉じている。関数内でchannelを生成、関数内でgoroutineを起動し生成したchanelの所有権を移譲する(goroutine内で書き込みと close をする)。そして関数の返り値は読み取り専用channelとする。

このように1つの関数内にchannelのライフサイクルを閉じ込める書き方が「Go言語による並行処理パターン」では頻出してる。

この handleSub 関数の書き換えによって、 Run 関数で生成した doneSubscriptions channelを孫goroutine( handleSub goroutine)まで渡さなくて良くなる。が、 handleSub 関数から返ってきた読み取り専用channelの読み込みでブロックするため新しくgoroutineを作る必要がある。

       case packet.SUBSCRIBE:
            suback, err := handler.HandleSubscribe(mqttReader)
            if err != nil {
                return err
            }
            _, err = conn.Write(suback.ToBytes())
                        if err != nil {
                                return err
                        }
-                       sub := make(chan *packet.Publish)
-                       subscriptionToBroker <- sub
-                       go handleSub(conn, sub)
+                       subscription, errCh := handleSub(clientID, conn)
+                       subscriptionToBroker <- subscription
+                       go func() {
+                               err, ok := <-errCh
+                               if !ok {
+                                       return
+                               }
+                               doneSubscriptions <- err
+                       }()

サブスクリプションの特定

MQTTのClient ID

これで errorhandleSub goroutineから Broker goroutineへ渡すことができそう。 error を受け取った Broker はどのSubscriptionを閉じるかをどうやって決めれば良いだろう?

今まで無視してたMQTTのClient IDを使うことにする。Client IDはCONNECT時にペイロードに含まれていた。

CONNECTパケットのハンドリング時に、Client IDを変数にセットしておく。

func handle(conn net.Conn, publishToBroker chan<- *packet.Publish, subscriptionToBroker chan<- Subscription) error {
        defer conn.Close()
 
+       var clientID string
+
        for {
                r := bufio.NewReader(conn)
                mqttReader := packet.NewMQTTReader(r)
                case packet.CONNECT:
-                       connack, err := handler.HandleConnect(mqttReader)
+                       connect, connack, err := handler.HandleConnect(mqttReader)
                        if err != nil {
                                return err
                        }
                        _, err = conn.Write(connack.ToBytes())
                        if err != nil {
                                return err
                        }
+                       clientID = connect.Payload.ClientID
                case packet.SUBSCRIBE:
                        suback, err := handler.HandleSubscribe(mqttReader)
                        if err != nil {

現在の実装だと以下のように読み取り専用channelを Subscription としている。

type Subscription chan<- *packet.Publish

これをstructにしてClient IDを持たせるようにする。さらに、 handleSub goroutineから Broker goroutineへ伝える error もsturctで包んでClient IDを関連づける。

// broker.go

type Subscription struct {
    clientID string
    pubToSub chan<- *packet.Publish
}

func NewSubscription(clientID string) (*Subscription, <-chan *packet.Publish) {
    pub := make(chan *packet.Publish)
    s := &Subscription{
        clientID: clientID,
        pubToSub: pub,
    }
    return s, pub
}

type DoneSubscriptionResult struct {
    clientID string
    err      error
}

func NewDoneSubscriptionResult(clientID string, err error) *DoneSubscriptionResult {
    return &DoneSubscriptionResult{clientID, err}
}

func Broker(fromPub <-chan *packet.Publish, subscriptions <-chan *Subscription, doneSubscriptions <-chan *DoneSubscriptionResult) {
    ...
}

map

現在の Broker の実装では Subscription を配列で管理してる。

func Broker(fromPub <-chan *packet.Publish, subscriptions <-chan *Subscription, doneSubscriptions <-chan *DoneSubscriptionResult) {
    // サブスクリプションの配列
    var ss []Subscription
    for {
        select {
        case sub := <- subscriptions:
            // channelからサブスクリプションを読み取ったら配列に追加
            ss = append(ss, sub)
        case message := <- fromPub:
            // fromPub channelからメッセージを読み取ったら全てのサブスクリプションへ配送
            for _, sub := range ss {
                sub <- message
            }
        case err := <-doneSubscriptions:
            fmt.Println(err)
        }
    }
}

Subscription を配列で管理してるのをやめて、キーをClientIDとしたmapを使うことにする。

func Broker(fromPub <-chan *packet.Publish, subscriptions <-chan *Subscription, doneSubscriptions <-chan *DoneSubscriptionResult) {
    // サブスクリプションのmap
    var sMap map[string]*Subscription
    sMap = make(map[string]*Subscription)
    for {
        select {
        case sub := <-subscriptions:
            // channelからサブスクリプションを読み取ったらキーをclientIDとしてmapに追加
            sMap[sub.clientID] = sub
        case message := <-fromPub:
            for _, sub := range sMap {
                sub.pubToSub <- message
            }
        case err := <-doneSubscriptions:
            fmt.Println(err)
        }
    }
}

次は selectdoneSubscriptions channelから読み取った時の処理。 close とmapから削除する処理をすれば良い。

func Broker(fromPub <-chan *packet.Publish, subscriptions <-chan *Subscription, doneSubscriptions <-chan *DoneSubscriptionResult) {
    // サブスクリプションのmap
    var sMap map[string]*Subscription
    sMap = make(map[string]*Subscription)
    for {
        select {
        case sub := <-subscriptions:
            // channelからサブスクリプションを読み取ったらキーをclientIDとしてmapに追加
            sMap[sub.clientID] = sub
        case message := <-fromPub:
            for _, sub := range sMap {
                sub.pubToSub <- message
            }
        case done := <-doneSubscriptions:
            fmt.Printf("close subscription: %v\n", done.clientID)
            if done.err != nil {
                fmt.Println(done.err)
            }
            s, ok := sMap[sub.clientID]
            if ok {
                sub := s.(*Subscription)
                close(s.pubToSub)
                delete(sMap, done.clientID)
            }
        }
    }
}

sync.Mutex

複数goroutineで同じmapを操作しているのが気になる。mapは複数のgoroutineから触れるように設計されていない。スレッドセーフならぬgoroutineセーフではない。

sync.Mutexパッケージを使って同期をとりgoroutineセーフにしておく。

sync.Mutexを使うときは1つの構造体、というか型に閉じておくと良さそう。 typesubscriptionMap というstructを作る。

type subscriptionMap struct {
    mu    sync.Mutex
    value map[string]*Subscription
}

func newSubscriptionMap() *subscriptionMap {
    m := make(map[string]*Subscription)
    return &subscriptionMap{
        mu:    sync.Mutex{},
        value: m,
    }
}

func (m *subscriptionMap) get(clientID string) *Subscription {
    m.mu.Lock()
    s, ok := sMap[sub.clientID]
    m.mu.Unlock()
    if ok {
        return s.(*Subscription)
    }
    return nil
}

func (m *subscriptionMap) put(clientID string, s *Subscription) {
    m.mu.Lock()
    m.value[clientID] = s
    m.mu.Unlock()
}

func (m *subscriptionMap) delete(clientID string) {
    m.mu.Lock()
    delete(m.value, clientID)
    m.mu.Unlock()
}

と、ここまで書いて、mapに対する range はどうしたら良いのだろう?という疑問が...

sync.Map

ググってると sync.Map というのを発見!

Go Docを見ると sync.Map で扱う値はinterface{}であることが分かる。 type を使って自前のstructで包んだ方が良さそう。

結果的にbroker.goは以下のようになった。

type Subscription struct {
    clientID string
    pubToSub chan<- *packet.Publish
}

func NewSubscription(clientID string) (*Subscription, <-chan *packet.Publish) {
    pub := make(chan *packet.Publish)
    s := &Subscription{
        clientID: clientID,
        pubToSub: pub,
    }
    return s, pub
}

type DoneSubscriptionResult struct {
    clientID string
    err      error
}

func NewDoneSubscriptionResult(clientID string, err error) *DoneSubscriptionResult {
    return &DoneSubscriptionResult{clientID, err}
}

type subscriptionMap struct {
    syncMap sync.Map
}

func newSubscriptionMap() *subscriptionMap {
    return &subscriptionMap{}
}

func (m *subscriptionMap) get(clientID string) *Subscription {
    s, ok := m.syncMap.Load(clientID)
    if !ok {
        return nil
    }
    return s.(*Subscription)
}

func (m *subscriptionMap) put(clientID string, s *Subscription) {
    m.syncMap.Store(clientID, s)
}

func (m *subscriptionMap) delete(clientID string) {
    m.syncMap.Delete(clientID)
}

func (m *subscriptionMap) apply(f func(s *Subscription)) {
    m.syncMap.Range(func(k, v interface{}) bool {
        s := v.(*Subscription)
        f(s)
        return true
    })
}

func Broker(fromPub <-chan *packet.Publish, subscriptions <-chan *Subscription, doneSubscriptions <-chan *DoneSubscriptionResult) {
    // サブスクリプションのmap
    sMap := newSubscriptionMap()
    for {
        select {
        case sub := <-subscriptions:
            // channelからサブスクリプションを読み取ったらキーをclientIDとしてmapに追加
            sMap.put(sub.clientID, sub)
        case message := <-fromPub:
            // 全てのサブスクリプションにメッセージを配送
            sMap.apply(func(sub *Subscription) {
                sub.pubToSub <- message
            })
        case done := <-doneSubscriptions:
            fmt.Printf("close subscription: %v\n", done.clientID)
            if done.err != nil {
                fmt.Println(done.err)
            }
            sub := sMap.get(done.clientID)
            if sub != nil {
                close(sub.pubToSub)
                sMap.delete(done.clientID)
            }
        }
    }
}

動かす

動かしてみる。

mosquitto_subでサブスクライブした後、すぐにCtrl-Cで切断する。その後にmosquitto_pubでパブリッシュするとサーバーが以下を出力する。

write tcp 127.0.0.1:1883->127.0.0.1:59452: use of closed network connection

ここまでは最初と変わらないけど、もう一度パブリッシュするともう出力されない。これは1回目の Write でエラーが発生した時に、 Brokerサブスクリプションの削除ができたから。

これでhandleSubでエラーが発生した場合にBrokerでハンドリングする実装ができた!

おしまい

ただこれだとエラーが発生するまでサブスクリプションが削除されないし、 handleSub goroutineが残り続ける。本当はクライアントが切断したタイミングでサブスクリプションを削除したい。

次回はContextを使うことでこれを解決しようと思います。

今回の学び。