aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorEthel Morgan <eth@ethulhu.co.uk>2020-06-23 15:52:22 +0100
committerEthel Morgan <eth@ethulhu.co.uk>2020-06-23 15:52:22 +0100
commite6e67a4015a5de5ffb35e938776db7afe6382446 (patch)
treedd9095118df4418742da62dd20878f4ddad9b96d
parenta9ffdec5c266682dc3d59145fb57f0d763749c8d (diff)
import catbus library from catbus-wakeonlan
-rw-r--r--catbus.go191
-rw-r--r--catbus_test.go206
-rw-r--r--go.mod12
-rw-r--r--go.sum8
4 files changed, 417 insertions, 0 deletions
diff --git a/catbus.go b/catbus.go
new file mode 100644
index 0000000..bb674c6
--- /dev/null
+++ b/catbus.go
@@ -0,0 +1,191 @@
+// SPDX-FileCopyrightText: 2020 Ethel Morgan
+//
+// SPDX-License-Identifier: MIT
+
+// Package catbus is a convenience wrapper around MQTT for use with Catbus.
+package catbus
+
+import (
+ "math/rand"
+ "sync"
+ "time"
+
+ mqtt "github.com/eclipse/paho.mqtt.golang"
+)
+
+type (
+ Message = mqtt.Message
+ MessageHandler = func(*Client, Message)
+
+ Client struct {
+ mqtt mqtt.Client
+
+ payloadByTopicMu sync.Mutex
+ payloadByTopic map[string][]byte
+
+ onconnectTimerByTopicMu sync.Mutex
+ onconnectTimerByTopic map[string]*time.Timer
+
+ onconnectDelay time.Duration
+ onconnectJitter time.Duration
+ }
+
+ ClientOptions struct {
+ DisconnectHandler func(*Client, error)
+ ConnectHandler func(*Client)
+
+ // Publish previously seen or default values on connecting after OnconnectDelay ± [0,OnconnectJitter).
+ OnconnectDelay time.Duration
+ OnconnectJitter time.Duration
+
+ // DefaultPayloadByTopic are optional values to publish on connect if no prior values are seen.
+ // E.g. unless we've been told otherwise, assume a device is off.
+ DefaultPayloadByTopic map[string][]byte
+ }
+
+ // Retention is whether or not the MQTT broker should retain the message.
+ Retention bool
+)
+
+const (
+ atMostOnce byte = iota
+ atLeastOnce
+ exactlyOnce
+)
+
+const (
+ Retain = Retention(true)
+ DontRetain = Retention(false)
+)
+
+const (
+ DefaultOnconnectDelay = 1 * time.Minute
+ DefaultOnconnectJitter = 15 * time.Second
+)
+
+func NewClient(brokerURI string, options ClientOptions) *Client {
+ client := &Client{
+ payloadByTopic: map[string][]byte{},
+ onconnectTimerByTopic: map[string]*time.Timer{},
+
+ onconnectDelay: DefaultOnconnectDelay,
+ onconnectJitter: DefaultOnconnectJitter,
+ }
+
+ if options.OnconnectDelay != 0 {
+ client.onconnectDelay = options.OnconnectDelay
+ }
+ if options.OnconnectJitter != 0 {
+ client.onconnectJitter = options.OnconnectJitter
+ }
+ for topic, payload := range options.DefaultPayloadByTopic {
+ client.payloadByTopic[topic] = payload
+ }
+
+ mqttOpts := mqtt.NewClientOptions()
+ mqttOpts.AddBroker(brokerURI)
+ mqttOpts.SetAutoReconnect(true)
+ mqttOpts.SetOnConnectHandler(func(c mqtt.Client) {
+ client.stopAllTimers()
+ client.startAllTimers()
+
+ if options.ConnectHandler != nil {
+ options.ConnectHandler(client)
+ }
+ })
+ mqttOpts.SetConnectionLostHandler(func(c mqtt.Client, err error) {
+ client.stopAllTimers()
+
+ if options.DisconnectHandler != nil {
+ options.DisconnectHandler(client, err)
+ }
+ })
+ client.mqtt = mqtt.NewClient(mqttOpts)
+
+ return client
+}
+
+// Connect connects to the Catbus MQTT broker and blocks forever.
+func (c *Client) Connect() error {
+ if err := c.mqtt.Connect().Error(); err != nil {
+ return err
+ }
+ select {}
+}
+
+// Subscribe subscribes to a Catbus MQTT topic.
+func (c *Client) Subscribe(topic string, f MessageHandler) error {
+ return c.mqtt.Subscribe(topic, atLeastOnce, func(_ mqtt.Client, msg mqtt.Message) {
+ c.storePayload(msg.Topic(), Retention(msg.Retained()), msg.Payload())
+
+ f(c, msg)
+ }).Error()
+}
+
+// Publish publishes to a Catbus MQTT topic.
+func (c *Client) Publish(topic string, retention Retention, payload []byte) error {
+ c.storePayload(topic, retention, payload)
+
+ return c.mqtt.Publish(topic, atLeastOnce, bool(retention), payload).Error()
+}
+
+func (c *Client) jitteredOnconnectDelay() time.Duration {
+ jitter := time.Duration(rand.Intn(int(c.onconnectJitter)))
+ if rand.Intn(2) == 0 {
+ return c.onconnectDelay + jitter
+ }
+ return c.onconnectDelay - jitter
+}
+
+func (c *Client) storePayload(topic string, retention Retention, payload []byte) {
+ c.payloadByTopicMu.Lock()
+ defer c.payloadByTopicMu.Unlock()
+
+ if _, ok := c.payloadByTopic[topic]; !ok && retention == DontRetain {
+ // If we don't have a copy, and the sender doesn't want it retained, don't retain it.
+ return
+ }
+
+ c.stopTimer(topic)
+
+ if len(payload) == 0 {
+ delete(c.payloadByTopic, topic)
+ return
+ }
+ c.payloadByTopic[topic] = payload
+}
+func (c *Client) stopTimer(topic string) {
+ c.onconnectTimerByTopicMu.Lock()
+ defer c.onconnectTimerByTopicMu.Unlock()
+
+ if timer, ok := c.onconnectTimerByTopic[topic]; ok {
+ _ = timer.Stop()
+ }
+}
+func (c *Client) stopAllTimers() {
+ c.onconnectTimerByTopicMu.Lock()
+ defer c.onconnectTimerByTopicMu.Unlock()
+
+ for _, timer := range c.onconnectTimerByTopic {
+ _ = timer.Stop()
+ }
+}
+func (c *Client) startAllTimers() {
+ c.payloadByTopicMu.Lock()
+ defer c.payloadByTopicMu.Unlock()
+
+ c.onconnectTimerByTopicMu.Lock()
+ defer c.onconnectTimerByTopicMu.Unlock()
+
+ for topic := range c.payloadByTopic {
+ c.onconnectTimerByTopic[topic] = time.AfterFunc(c.jitteredOnconnectDelay(), func() {
+ c.payloadByTopicMu.Lock()
+ payload, ok := c.payloadByTopic[topic]
+ c.payloadByTopicMu.Unlock()
+ if !ok {
+ return
+ }
+ _ = c.Publish(topic, Retain, payload)
+ })
+ }
+}
diff --git a/catbus_test.go b/catbus_test.go
new file mode 100644
index 0000000..d07367b
--- /dev/null
+++ b/catbus_test.go
@@ -0,0 +1,206 @@
+// SPDX-FileCopyrightText: 2020 Ethel Morgan
+//
+// SPDX-License-Identifier: MIT
+
+package catbus
+
+import (
+ "fmt"
+ "log"
+ "reflect"
+ "testing"
+ "time"
+
+ mqtt "github.com/eclipse/paho.mqtt.golang"
+)
+
+type (
+ message struct {
+ retention Retention
+ payload []byte
+ }
+)
+
+func TestOnConnect(t *testing.T) {
+ tests := []struct {
+ payloadByTopic map[string][]byte
+ subscribe []string
+ receive map[string]message
+
+ want map[string][]byte
+ }{
+ {
+ payloadByTopic: map[string][]byte{
+ "tv/power": []byte("off"),
+ },
+ want: map[string][]byte{
+ "tv/power": []byte("off"),
+ },
+ },
+ {
+ payloadByTopic: map[string][]byte{
+ "tv/power": []byte("off"),
+ },
+ subscribe: []string{
+ "tv/power",
+ },
+ receive: map[string]message{
+ "tv/power": {Retain, []byte("on")},
+ },
+ want: map[string][]byte{
+ "tv/power": []byte("on"),
+ },
+ },
+ {
+ subscribe: []string{
+ "tv/power",
+ },
+ receive: map[string]message{
+ "tv/power": {Retain, []byte("on")},
+ },
+ want: map[string][]byte{
+ "tv/power": []byte("on"),
+ },
+ },
+ {
+ payloadByTopic: map[string][]byte{
+ "tv/power": []byte("off"),
+ },
+ subscribe: []string{
+ "tv/power",
+ },
+ receive: map[string]message{
+ "tv/power": {DontRetain, []byte("on")},
+ },
+ want: map[string][]byte{
+ "tv/power": []byte("on"),
+ },
+ },
+ {
+ subscribe: []string{
+ "tv/power",
+ },
+ receive: map[string]message{
+ "tv/power": {DontRetain, []byte("on")},
+ },
+ want: map[string][]byte{},
+ },
+ {
+ payloadByTopic: map[string][]byte{
+ "tv/power": []byte("off"),
+ },
+ subscribe: []string{
+ "tv/power",
+ },
+ receive: map[string]message{
+ "tv/power": {DontRetain, []byte{}},
+ },
+ want: map[string][]byte{},
+ },
+ }
+
+ for i, tt := range tests {
+ fakeMQTT := &fakeMQTT{
+ callbackByTopic: map[string]mqtt.MessageHandler{},
+ payloadByTopic: map[string][]byte{},
+ }
+
+ catbus := &Client{
+ mqtt: fakeMQTT,
+ payloadByTopic: map[string][]byte{},
+ onconnectTimerByTopic: map[string]*time.Timer{},
+ onconnectDelay: 1 * time.Millisecond,
+ onconnectJitter: 1,
+ }
+ if tt.payloadByTopic != nil {
+ catbus.payloadByTopic = tt.payloadByTopic
+ }
+
+ for _, topic := range tt.subscribe {
+ catbus.Subscribe(topic, func(_ *Client, _ Message) {})
+ }
+ for topic, message := range tt.receive {
+ fakeMQTT.send(topic, message.retention, message.payload)
+ }
+
+ catbus.stopAllTimers()
+ catbus.startAllTimers()
+
+ // TODO: replace with proper channel signaling or sth.
+ time.Sleep(1 * time.Second)
+
+ got := fakeMQTT.payloadByTopic
+ if !reflect.DeepEqual(got, tt.want) {
+ t.Errorf("[%d]: got %v, want %v", i, got, tt.want)
+ }
+ }
+}
+
+type (
+ fakeMQTT struct {
+ mqtt.Client
+
+ callbackByTopic map[string]mqtt.MessageHandler
+ payloadByTopic map[string][]byte
+ }
+
+ fakeMessage struct {
+ mqtt.Message
+
+ topic string
+ retained bool
+ payload []byte
+ }
+
+ fakeToken struct{}
+)
+
+func (f *fakeMQTT) Publish(topic string, qos byte, retain bool, payload interface{}) mqtt.Token {
+ bytes, ok := payload.([]byte)
+ if !ok {
+ panic(fmt.Sprintf("expected type []byte, got %v", reflect.TypeOf(payload)))
+ }
+
+ log.Printf("topic %q payload %s", topic, payload)
+ f.payloadByTopic[topic] = bytes
+ return &fakeToken{}
+}
+func (f *fakeMQTT) Subscribe(topic string, qos byte, callback mqtt.MessageHandler) mqtt.Token {
+ f.callbackByTopic[topic] = callback
+
+ return &fakeToken{}
+}
+func (f *fakeMQTT) send(topic string, retention Retention, payload []byte) {
+ // if retention == Retain {
+ // f.payloadByTopic[topic] = payload
+ // }
+
+ if callback, ok := f.callbackByTopic[topic]; ok {
+ msg := &fakeMessage{
+ topic: topic,
+ retained: bool(retention),
+ payload: payload,
+ }
+ callback(f, msg)
+ }
+}
+
+func (f *fakeMessage) Topic() string {
+ return f.topic
+}
+func (f *fakeMessage) Payload() []byte {
+ return f.payload
+}
+func (f *fakeMessage) Retained() bool {
+ return f.retained
+}
+
+func (_ *fakeToken) Wait() bool {
+ return false
+}
+func (_ *fakeToken) WaitTimeout(_ time.Duration) bool {
+ return false
+}
+func (_ *fakeToken) Error() error {
+ return nil
+}
diff --git a/go.mod b/go.mod
new file mode 100644
index 0000000..bf3ebcd
--- /dev/null
+++ b/go.mod
@@ -0,0 +1,12 @@
+// SPDX-FileCopyrightText: 2020 Ethel Morgan
+//
+// SPDX-License-Identifier: MIT
+
+module go.eth.moe/catbus
+
+go 1.14
+
+require (
+ github.com/eclipse/paho.mqtt.golang v1.2.0
+ golang.org/x/net v0.0.0-20200602114024-627f9648deb9 // indirect
+)
diff --git a/go.sum b/go.sum
new file mode 100644
index 0000000..d996e7c
--- /dev/null
+++ b/go.sum
@@ -0,0 +1,8 @@
+github.com/eclipse/paho.mqtt.golang v1.2.0 h1:1F8mhG9+aO5/xpdtFkW4SxOJB67ukuDC3t2y2qayIX0=
+github.com/eclipse/paho.mqtt.golang v1.2.0/go.mod h1:H9keYFcgq3Qr5OUJm/JZI/i6U7joQ8SYLhZwfeOo6Ts=
+golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
+golang.org/x/net v0.0.0-20200602114024-627f9648deb9 h1:pNX+40auqi2JqRfOP1akLGtYcn15TUbkhwuCO3foqqM=
+golang.org/x/net v0.0.0-20200602114024-627f9648deb9/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A=
+golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
+golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=