diff options
author | Ethel Morgan <eth@ethulhu.co.uk> | 2020-06-23 15:52:22 +0100 |
---|---|---|
committer | Ethel Morgan <eth@ethulhu.co.uk> | 2020-06-23 15:52:22 +0100 |
commit | e6e67a4015a5de5ffb35e938776db7afe6382446 (patch) | |
tree | dd9095118df4418742da62dd20878f4ddad9b96d | |
parent | a9ffdec5c266682dc3d59145fb57f0d763749c8d (diff) |
import catbus library from catbus-wakeonlan
-rw-r--r-- | catbus.go | 191 | ||||
-rw-r--r-- | catbus_test.go | 206 | ||||
-rw-r--r-- | go.mod | 12 | ||||
-rw-r--r-- | go.sum | 8 |
4 files changed, 417 insertions, 0 deletions
diff --git a/catbus.go b/catbus.go new file mode 100644 index 0000000..bb674c6 --- /dev/null +++ b/catbus.go @@ -0,0 +1,191 @@ +// SPDX-FileCopyrightText: 2020 Ethel Morgan +// +// SPDX-License-Identifier: MIT + +// Package catbus is a convenience wrapper around MQTT for use with Catbus. +package catbus + +import ( + "math/rand" + "sync" + "time" + + mqtt "github.com/eclipse/paho.mqtt.golang" +) + +type ( + Message = mqtt.Message + MessageHandler = func(*Client, Message) + + Client struct { + mqtt mqtt.Client + + payloadByTopicMu sync.Mutex + payloadByTopic map[string][]byte + + onconnectTimerByTopicMu sync.Mutex + onconnectTimerByTopic map[string]*time.Timer + + onconnectDelay time.Duration + onconnectJitter time.Duration + } + + ClientOptions struct { + DisconnectHandler func(*Client, error) + ConnectHandler func(*Client) + + // Publish previously seen or default values on connecting after OnconnectDelay ± [0,OnconnectJitter). + OnconnectDelay time.Duration + OnconnectJitter time.Duration + + // DefaultPayloadByTopic are optional values to publish on connect if no prior values are seen. + // E.g. unless we've been told otherwise, assume a device is off. + DefaultPayloadByTopic map[string][]byte + } + + // Retention is whether or not the MQTT broker should retain the message. + Retention bool +) + +const ( + atMostOnce byte = iota + atLeastOnce + exactlyOnce +) + +const ( + Retain = Retention(true) + DontRetain = Retention(false) +) + +const ( + DefaultOnconnectDelay = 1 * time.Minute + DefaultOnconnectJitter = 15 * time.Second +) + +func NewClient(brokerURI string, options ClientOptions) *Client { + client := &Client{ + payloadByTopic: map[string][]byte{}, + onconnectTimerByTopic: map[string]*time.Timer{}, + + onconnectDelay: DefaultOnconnectDelay, + onconnectJitter: DefaultOnconnectJitter, + } + + if options.OnconnectDelay != 0 { + client.onconnectDelay = options.OnconnectDelay + } + if options.OnconnectJitter != 0 { + client.onconnectJitter = options.OnconnectJitter + } + for topic, payload := range options.DefaultPayloadByTopic { + client.payloadByTopic[topic] = payload + } + + mqttOpts := mqtt.NewClientOptions() + mqttOpts.AddBroker(brokerURI) + mqttOpts.SetAutoReconnect(true) + mqttOpts.SetOnConnectHandler(func(c mqtt.Client) { + client.stopAllTimers() + client.startAllTimers() + + if options.ConnectHandler != nil { + options.ConnectHandler(client) + } + }) + mqttOpts.SetConnectionLostHandler(func(c mqtt.Client, err error) { + client.stopAllTimers() + + if options.DisconnectHandler != nil { + options.DisconnectHandler(client, err) + } + }) + client.mqtt = mqtt.NewClient(mqttOpts) + + return client +} + +// Connect connects to the Catbus MQTT broker and blocks forever. +func (c *Client) Connect() error { + if err := c.mqtt.Connect().Error(); err != nil { + return err + } + select {} +} + +// Subscribe subscribes to a Catbus MQTT topic. +func (c *Client) Subscribe(topic string, f MessageHandler) error { + return c.mqtt.Subscribe(topic, atLeastOnce, func(_ mqtt.Client, msg mqtt.Message) { + c.storePayload(msg.Topic(), Retention(msg.Retained()), msg.Payload()) + + f(c, msg) + }).Error() +} + +// Publish publishes to a Catbus MQTT topic. +func (c *Client) Publish(topic string, retention Retention, payload []byte) error { + c.storePayload(topic, retention, payload) + + return c.mqtt.Publish(topic, atLeastOnce, bool(retention), payload).Error() +} + +func (c *Client) jitteredOnconnectDelay() time.Duration { + jitter := time.Duration(rand.Intn(int(c.onconnectJitter))) + if rand.Intn(2) == 0 { + return c.onconnectDelay + jitter + } + return c.onconnectDelay - jitter +} + +func (c *Client) storePayload(topic string, retention Retention, payload []byte) { + c.payloadByTopicMu.Lock() + defer c.payloadByTopicMu.Unlock() + + if _, ok := c.payloadByTopic[topic]; !ok && retention == DontRetain { + // If we don't have a copy, and the sender doesn't want it retained, don't retain it. + return + } + + c.stopTimer(topic) + + if len(payload) == 0 { + delete(c.payloadByTopic, topic) + return + } + c.payloadByTopic[topic] = payload +} +func (c *Client) stopTimer(topic string) { + c.onconnectTimerByTopicMu.Lock() + defer c.onconnectTimerByTopicMu.Unlock() + + if timer, ok := c.onconnectTimerByTopic[topic]; ok { + _ = timer.Stop() + } +} +func (c *Client) stopAllTimers() { + c.onconnectTimerByTopicMu.Lock() + defer c.onconnectTimerByTopicMu.Unlock() + + for _, timer := range c.onconnectTimerByTopic { + _ = timer.Stop() + } +} +func (c *Client) startAllTimers() { + c.payloadByTopicMu.Lock() + defer c.payloadByTopicMu.Unlock() + + c.onconnectTimerByTopicMu.Lock() + defer c.onconnectTimerByTopicMu.Unlock() + + for topic := range c.payloadByTopic { + c.onconnectTimerByTopic[topic] = time.AfterFunc(c.jitteredOnconnectDelay(), func() { + c.payloadByTopicMu.Lock() + payload, ok := c.payloadByTopic[topic] + c.payloadByTopicMu.Unlock() + if !ok { + return + } + _ = c.Publish(topic, Retain, payload) + }) + } +} diff --git a/catbus_test.go b/catbus_test.go new file mode 100644 index 0000000..d07367b --- /dev/null +++ b/catbus_test.go @@ -0,0 +1,206 @@ +// SPDX-FileCopyrightText: 2020 Ethel Morgan +// +// SPDX-License-Identifier: MIT + +package catbus + +import ( + "fmt" + "log" + "reflect" + "testing" + "time" + + mqtt "github.com/eclipse/paho.mqtt.golang" +) + +type ( + message struct { + retention Retention + payload []byte + } +) + +func TestOnConnect(t *testing.T) { + tests := []struct { + payloadByTopic map[string][]byte + subscribe []string + receive map[string]message + + want map[string][]byte + }{ + { + payloadByTopic: map[string][]byte{ + "tv/power": []byte("off"), + }, + want: map[string][]byte{ + "tv/power": []byte("off"), + }, + }, + { + payloadByTopic: map[string][]byte{ + "tv/power": []byte("off"), + }, + subscribe: []string{ + "tv/power", + }, + receive: map[string]message{ + "tv/power": {Retain, []byte("on")}, + }, + want: map[string][]byte{ + "tv/power": []byte("on"), + }, + }, + { + subscribe: []string{ + "tv/power", + }, + receive: map[string]message{ + "tv/power": {Retain, []byte("on")}, + }, + want: map[string][]byte{ + "tv/power": []byte("on"), + }, + }, + { + payloadByTopic: map[string][]byte{ + "tv/power": []byte("off"), + }, + subscribe: []string{ + "tv/power", + }, + receive: map[string]message{ + "tv/power": {DontRetain, []byte("on")}, + }, + want: map[string][]byte{ + "tv/power": []byte("on"), + }, + }, + { + subscribe: []string{ + "tv/power", + }, + receive: map[string]message{ + "tv/power": {DontRetain, []byte("on")}, + }, + want: map[string][]byte{}, + }, + { + payloadByTopic: map[string][]byte{ + "tv/power": []byte("off"), + }, + subscribe: []string{ + "tv/power", + }, + receive: map[string]message{ + "tv/power": {DontRetain, []byte{}}, + }, + want: map[string][]byte{}, + }, + } + + for i, tt := range tests { + fakeMQTT := &fakeMQTT{ + callbackByTopic: map[string]mqtt.MessageHandler{}, + payloadByTopic: map[string][]byte{}, + } + + catbus := &Client{ + mqtt: fakeMQTT, + payloadByTopic: map[string][]byte{}, + onconnectTimerByTopic: map[string]*time.Timer{}, + onconnectDelay: 1 * time.Millisecond, + onconnectJitter: 1, + } + if tt.payloadByTopic != nil { + catbus.payloadByTopic = tt.payloadByTopic + } + + for _, topic := range tt.subscribe { + catbus.Subscribe(topic, func(_ *Client, _ Message) {}) + } + for topic, message := range tt.receive { + fakeMQTT.send(topic, message.retention, message.payload) + } + + catbus.stopAllTimers() + catbus.startAllTimers() + + // TODO: replace with proper channel signaling or sth. + time.Sleep(1 * time.Second) + + got := fakeMQTT.payloadByTopic + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("[%d]: got %v, want %v", i, got, tt.want) + } + } +} + +type ( + fakeMQTT struct { + mqtt.Client + + callbackByTopic map[string]mqtt.MessageHandler + payloadByTopic map[string][]byte + } + + fakeMessage struct { + mqtt.Message + + topic string + retained bool + payload []byte + } + + fakeToken struct{} +) + +func (f *fakeMQTT) Publish(topic string, qos byte, retain bool, payload interface{}) mqtt.Token { + bytes, ok := payload.([]byte) + if !ok { + panic(fmt.Sprintf("expected type []byte, got %v", reflect.TypeOf(payload))) + } + + log.Printf("topic %q payload %s", topic, payload) + f.payloadByTopic[topic] = bytes + return &fakeToken{} +} +func (f *fakeMQTT) Subscribe(topic string, qos byte, callback mqtt.MessageHandler) mqtt.Token { + f.callbackByTopic[topic] = callback + + return &fakeToken{} +} +func (f *fakeMQTT) send(topic string, retention Retention, payload []byte) { + // if retention == Retain { + // f.payloadByTopic[topic] = payload + // } + + if callback, ok := f.callbackByTopic[topic]; ok { + msg := &fakeMessage{ + topic: topic, + retained: bool(retention), + payload: payload, + } + callback(f, msg) + } +} + +func (f *fakeMessage) Topic() string { + return f.topic +} +func (f *fakeMessage) Payload() []byte { + return f.payload +} +func (f *fakeMessage) Retained() bool { + return f.retained +} + +func (_ *fakeToken) Wait() bool { + return false +} +func (_ *fakeToken) WaitTimeout(_ time.Duration) bool { + return false +} +func (_ *fakeToken) Error() error { + return nil +} @@ -0,0 +1,12 @@ +// SPDX-FileCopyrightText: 2020 Ethel Morgan +// +// SPDX-License-Identifier: MIT + +module go.eth.moe/catbus + +go 1.14 + +require ( + github.com/eclipse/paho.mqtt.golang v1.2.0 + golang.org/x/net v0.0.0-20200602114024-627f9648deb9 // indirect +) @@ -0,0 +1,8 @@ +github.com/eclipse/paho.mqtt.golang v1.2.0 h1:1F8mhG9+aO5/xpdtFkW4SxOJB67ukuDC3t2y2qayIX0= +github.com/eclipse/paho.mqtt.golang v1.2.0/go.mod h1:H9keYFcgq3Qr5OUJm/JZI/i6U7joQ8SYLhZwfeOo6Ts= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/net v0.0.0-20200602114024-627f9648deb9 h1:pNX+40auqi2JqRfOP1akLGtYcn15TUbkhwuCO3foqqM= +golang.org/x/net v0.0.0-20200602114024-627f9648deb9/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= |