aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--catbus/catbus.go131
-rw-r--r--catbus/catbus_test.go206
-rw-r--r--cmd/catbus-actuator-wakeonlan/main.go4
3 files changed, 296 insertions, 45 deletions
diff --git a/catbus/catbus.go b/catbus/catbus.go
index 2538a2f..bb674c6 100644
--- a/catbus/catbus.go
+++ b/catbus/catbus.go
@@ -20,24 +20,27 @@ type (
Client struct {
mqtt mqtt.Client
- rebroadcastMu sync.Mutex
- rebroadcastByTopic map[string]*time.Timer
- rebroadcastPeriod time.Duration
- rebroadcastJitter time.Duration
- }
+ payloadByTopicMu sync.Mutex
+ payloadByTopic map[string][]byte
+
+ onconnectTimerByTopicMu sync.Mutex
+ onconnectTimerByTopic map[string]*time.Timer
+ onconnectDelay time.Duration
+ onconnectJitter time.Duration
+ }
ClientOptions struct {
DisconnectHandler func(*Client, error)
ConnectHandler func(*Client)
- // Rebroadcast previously seen values every RebroadcastPeriod ± [0,RebroadcastJitter).
- RebroadcastPeriod time.Duration
- RebroadcastJitter time.Duration
+ // Publish previously seen or default values on connecting after OnconnectDelay ± [0,OnconnectJitter).
+ OnconnectDelay time.Duration
+ OnconnectJitter time.Duration
- // RebroadcastDefaults are optional values to seed rebroadcasting if no prior values are seen.
+ // DefaultPayloadByTopic are optional values to publish on connect if no prior values are seen.
// E.g. unless we've been told otherwise, assume a device is off.
- RebroadcastDefaults map[string][]byte
+ DefaultPayloadByTopic map[string][]byte
}
// Retention is whether or not the MQTT broker should retain the message.
@@ -56,41 +59,47 @@ const (
)
const (
- DefaultRebroadcastPeriod = 1 * time.Minute
- DefaultRebroadcastJitter = 15 * time.Second
+ DefaultOnconnectDelay = 1 * time.Minute
+ DefaultOnconnectJitter = 15 * time.Second
)
func NewClient(brokerURI string, options ClientOptions) *Client {
client := &Client{
- rebroadcastByTopic: map[string]*time.Timer{},
- rebroadcastPeriod: DefaultRebroadcastPeriod,
- rebroadcastJitter: DefaultRebroadcastJitter,
+ payloadByTopic: map[string][]byte{},
+ onconnectTimerByTopic: map[string]*time.Timer{},
+
+ onconnectDelay: DefaultOnconnectDelay,
+ onconnectJitter: DefaultOnconnectJitter,
}
- if options.RebroadcastPeriod != 0 {
- client.rebroadcastPeriod = options.RebroadcastPeriod
+ if options.OnconnectDelay != 0 {
+ client.onconnectDelay = options.OnconnectDelay
}
- if options.RebroadcastJitter != 0 {
- client.rebroadcastJitter = options.RebroadcastJitter
+ if options.OnconnectJitter != 0 {
+ client.onconnectJitter = options.OnconnectJitter
}
- for topic, payload := range options.RebroadcastDefaults {
- // TODO: Allow users to set retention?
- client.rebroadcastLater(topic, Retain, payload)
+ for topic, payload := range options.DefaultPayloadByTopic {
+ client.payloadByTopic[topic] = payload
}
mqttOpts := mqtt.NewClientOptions()
mqttOpts.AddBroker(brokerURI)
mqttOpts.SetAutoReconnect(true)
- mqttOpts.SetConnectionLostHandler(func(c mqtt.Client, err error) {
- if options.DisconnectHandler != nil {
- options.DisconnectHandler(client, err)
- }
- })
mqttOpts.SetOnConnectHandler(func(c mqtt.Client) {
+ client.stopAllTimers()
+ client.startAllTimers()
+
if options.ConnectHandler != nil {
options.ConnectHandler(client)
}
})
+ mqttOpts.SetConnectionLostHandler(func(c mqtt.Client, err error) {
+ client.stopAllTimers()
+
+ if options.DisconnectHandler != nil {
+ options.DisconnectHandler(client, err)
+ }
+ })
client.mqtt = mqtt.NewClient(mqttOpts)
return client
@@ -107,7 +116,7 @@ func (c *Client) Connect() error {
// Subscribe subscribes to a Catbus MQTT topic.
func (c *Client) Subscribe(topic string, f MessageHandler) error {
return c.mqtt.Subscribe(topic, atLeastOnce, func(_ mqtt.Client, msg mqtt.Message) {
- c.rebroadcastLater(msg.Topic(), Retention(msg.Retained()), msg.Payload())
+ c.storePayload(msg.Topic(), Retention(msg.Retained()), msg.Payload())
f(c, msg)
}).Error()
@@ -115,32 +124,68 @@ func (c *Client) Subscribe(topic string, f MessageHandler) error {
// Publish publishes to a Catbus MQTT topic.
func (c *Client) Publish(topic string, retention Retention, payload []byte) error {
- c.rebroadcastLater(topic, retention, payload)
+ c.storePayload(topic, retention, payload)
return c.mqtt.Publish(topic, atLeastOnce, bool(retention), payload).Error()
}
-func (c *Client) rebroadcastLater(topic string, retention Retention, payload []byte) {
- c.rebroadcastMu.Lock()
- defer c.rebroadcastMu.Unlock()
+func (c *Client) jitteredOnconnectDelay() time.Duration {
+ jitter := time.Duration(rand.Intn(int(c.onconnectJitter)))
+ if rand.Intn(2) == 0 {
+ return c.onconnectDelay + jitter
+ }
+ return c.onconnectDelay - jitter
+}
- if timer := c.rebroadcastByTopic[topic]; timer != nil {
- _ = timer.Stop()
+func (c *Client) storePayload(topic string, retention Retention, payload []byte) {
+ c.payloadByTopicMu.Lock()
+ defer c.payloadByTopicMu.Unlock()
+
+ if _, ok := c.payloadByTopic[topic]; !ok && retention == DontRetain {
+ // If we don't have a copy, and the sender doesn't want it retained, don't retain it.
+ return
}
+ c.stopTimer(topic)
+
if len(payload) == 0 {
- // No payload => remove => don't rebroadcast.
+ delete(c.payloadByTopic, topic)
return
}
+ c.payloadByTopic[topic] = payload
+}
+func (c *Client) stopTimer(topic string) {
+ c.onconnectTimerByTopicMu.Lock()
+ defer c.onconnectTimerByTopicMu.Unlock()
- c.rebroadcastByTopic[topic] = time.AfterFunc(c.rebroadcastDuration(), func() {
- _ = c.Publish(topic, retention, payload)
- })
+ if timer, ok := c.onconnectTimerByTopic[topic]; ok {
+ _ = timer.Stop()
+ }
+}
+func (c *Client) stopAllTimers() {
+ c.onconnectTimerByTopicMu.Lock()
+ defer c.onconnectTimerByTopicMu.Unlock()
+
+ for _, timer := range c.onconnectTimerByTopic {
+ _ = timer.Stop()
+ }
}
-func (c *Client) rebroadcastDuration() time.Duration {
- jitter := time.Duration(rand.Intn(int(c.rebroadcastJitter)))
- if rand.Intn(1) == 0 {
- return c.rebroadcastPeriod + jitter
+func (c *Client) startAllTimers() {
+ c.payloadByTopicMu.Lock()
+ defer c.payloadByTopicMu.Unlock()
+
+ c.onconnectTimerByTopicMu.Lock()
+ defer c.onconnectTimerByTopicMu.Unlock()
+
+ for topic := range c.payloadByTopic {
+ c.onconnectTimerByTopic[topic] = time.AfterFunc(c.jitteredOnconnectDelay(), func() {
+ c.payloadByTopicMu.Lock()
+ payload, ok := c.payloadByTopic[topic]
+ c.payloadByTopicMu.Unlock()
+ if !ok {
+ return
+ }
+ _ = c.Publish(topic, Retain, payload)
+ })
}
- return c.rebroadcastPeriod - jitter
}
diff --git a/catbus/catbus_test.go b/catbus/catbus_test.go
new file mode 100644
index 0000000..d07367b
--- /dev/null
+++ b/catbus/catbus_test.go
@@ -0,0 +1,206 @@
+// SPDX-FileCopyrightText: 2020 Ethel Morgan
+//
+// SPDX-License-Identifier: MIT
+
+package catbus
+
+import (
+ "fmt"
+ "log"
+ "reflect"
+ "testing"
+ "time"
+
+ mqtt "github.com/eclipse/paho.mqtt.golang"
+)
+
+type (
+ message struct {
+ retention Retention
+ payload []byte
+ }
+)
+
+func TestOnConnect(t *testing.T) {
+ tests := []struct {
+ payloadByTopic map[string][]byte
+ subscribe []string
+ receive map[string]message
+
+ want map[string][]byte
+ }{
+ {
+ payloadByTopic: map[string][]byte{
+ "tv/power": []byte("off"),
+ },
+ want: map[string][]byte{
+ "tv/power": []byte("off"),
+ },
+ },
+ {
+ payloadByTopic: map[string][]byte{
+ "tv/power": []byte("off"),
+ },
+ subscribe: []string{
+ "tv/power",
+ },
+ receive: map[string]message{
+ "tv/power": {Retain, []byte("on")},
+ },
+ want: map[string][]byte{
+ "tv/power": []byte("on"),
+ },
+ },
+ {
+ subscribe: []string{
+ "tv/power",
+ },
+ receive: map[string]message{
+ "tv/power": {Retain, []byte("on")},
+ },
+ want: map[string][]byte{
+ "tv/power": []byte("on"),
+ },
+ },
+ {
+ payloadByTopic: map[string][]byte{
+ "tv/power": []byte("off"),
+ },
+ subscribe: []string{
+ "tv/power",
+ },
+ receive: map[string]message{
+ "tv/power": {DontRetain, []byte("on")},
+ },
+ want: map[string][]byte{
+ "tv/power": []byte("on"),
+ },
+ },
+ {
+ subscribe: []string{
+ "tv/power",
+ },
+ receive: map[string]message{
+ "tv/power": {DontRetain, []byte("on")},
+ },
+ want: map[string][]byte{},
+ },
+ {
+ payloadByTopic: map[string][]byte{
+ "tv/power": []byte("off"),
+ },
+ subscribe: []string{
+ "tv/power",
+ },
+ receive: map[string]message{
+ "tv/power": {DontRetain, []byte{}},
+ },
+ want: map[string][]byte{},
+ },
+ }
+
+ for i, tt := range tests {
+ fakeMQTT := &fakeMQTT{
+ callbackByTopic: map[string]mqtt.MessageHandler{},
+ payloadByTopic: map[string][]byte{},
+ }
+
+ catbus := &Client{
+ mqtt: fakeMQTT,
+ payloadByTopic: map[string][]byte{},
+ onconnectTimerByTopic: map[string]*time.Timer{},
+ onconnectDelay: 1 * time.Millisecond,
+ onconnectJitter: 1,
+ }
+ if tt.payloadByTopic != nil {
+ catbus.payloadByTopic = tt.payloadByTopic
+ }
+
+ for _, topic := range tt.subscribe {
+ catbus.Subscribe(topic, func(_ *Client, _ Message) {})
+ }
+ for topic, message := range tt.receive {
+ fakeMQTT.send(topic, message.retention, message.payload)
+ }
+
+ catbus.stopAllTimers()
+ catbus.startAllTimers()
+
+ // TODO: replace with proper channel signaling or sth.
+ time.Sleep(1 * time.Second)
+
+ got := fakeMQTT.payloadByTopic
+ if !reflect.DeepEqual(got, tt.want) {
+ t.Errorf("[%d]: got %v, want %v", i, got, tt.want)
+ }
+ }
+}
+
+type (
+ fakeMQTT struct {
+ mqtt.Client
+
+ callbackByTopic map[string]mqtt.MessageHandler
+ payloadByTopic map[string][]byte
+ }
+
+ fakeMessage struct {
+ mqtt.Message
+
+ topic string
+ retained bool
+ payload []byte
+ }
+
+ fakeToken struct{}
+)
+
+func (f *fakeMQTT) Publish(topic string, qos byte, retain bool, payload interface{}) mqtt.Token {
+ bytes, ok := payload.([]byte)
+ if !ok {
+ panic(fmt.Sprintf("expected type []byte, got %v", reflect.TypeOf(payload)))
+ }
+
+ log.Printf("topic %q payload %s", topic, payload)
+ f.payloadByTopic[topic] = bytes
+ return &fakeToken{}
+}
+func (f *fakeMQTT) Subscribe(topic string, qos byte, callback mqtt.MessageHandler) mqtt.Token {
+ f.callbackByTopic[topic] = callback
+
+ return &fakeToken{}
+}
+func (f *fakeMQTT) send(topic string, retention Retention, payload []byte) {
+ // if retention == Retain {
+ // f.payloadByTopic[topic] = payload
+ // }
+
+ if callback, ok := f.callbackByTopic[topic]; ok {
+ msg := &fakeMessage{
+ topic: topic,
+ retained: bool(retention),
+ payload: payload,
+ }
+ callback(f, msg)
+ }
+}
+
+func (f *fakeMessage) Topic() string {
+ return f.topic
+}
+func (f *fakeMessage) Payload() []byte {
+ return f.payload
+}
+func (f *fakeMessage) Retained() bool {
+ return f.retained
+}
+
+func (_ *fakeToken) Wait() bool {
+ return false
+}
+func (_ *fakeToken) WaitTimeout(_ time.Duration) bool {
+ return false
+}
+func (_ *fakeToken) Error() error {
+ return nil
+}
diff --git a/cmd/catbus-actuator-wakeonlan/main.go b/cmd/catbus-actuator-wakeonlan/main.go
index 6f4e62b..2649206 100644
--- a/cmd/catbus-actuator-wakeonlan/main.go
+++ b/cmd/catbus-actuator-wakeonlan/main.go
@@ -74,9 +74,9 @@ func main() {
},
}
- catbusOptions.RebroadcastDefaults = map[string][]byte{}
+ catbusOptions.DefaultPayloadByTopic = map[string][]byte{}
for topic := range config.MACsByTopic {
- catbusOptions.RebroadcastDefaults[topic] = []byte("off")
+ catbusOptions.DefaultPayloadByTopic[topic] = []byte("off")
}
catbus := catbus.NewClient(config.BrokerURI, catbusOptions)