diff options
Diffstat (limited to '')
-rw-r--r-- | catbus/catbus.go | 131 | ||||
-rw-r--r-- | catbus/catbus_test.go | 206 |
2 files changed, 294 insertions, 43 deletions
diff --git a/catbus/catbus.go b/catbus/catbus.go index 2538a2f..bb674c6 100644 --- a/catbus/catbus.go +++ b/catbus/catbus.go @@ -20,24 +20,27 @@ type ( Client struct { mqtt mqtt.Client - rebroadcastMu sync.Mutex - rebroadcastByTopic map[string]*time.Timer - rebroadcastPeriod time.Duration - rebroadcastJitter time.Duration - } + payloadByTopicMu sync.Mutex + payloadByTopic map[string][]byte + + onconnectTimerByTopicMu sync.Mutex + onconnectTimerByTopic map[string]*time.Timer + onconnectDelay time.Duration + onconnectJitter time.Duration + } ClientOptions struct { DisconnectHandler func(*Client, error) ConnectHandler func(*Client) - // Rebroadcast previously seen values every RebroadcastPeriod ± [0,RebroadcastJitter). - RebroadcastPeriod time.Duration - RebroadcastJitter time.Duration + // Publish previously seen or default values on connecting after OnconnectDelay ± [0,OnconnectJitter). + OnconnectDelay time.Duration + OnconnectJitter time.Duration - // RebroadcastDefaults are optional values to seed rebroadcasting if no prior values are seen. + // DefaultPayloadByTopic are optional values to publish on connect if no prior values are seen. // E.g. unless we've been told otherwise, assume a device is off. - RebroadcastDefaults map[string][]byte + DefaultPayloadByTopic map[string][]byte } // Retention is whether or not the MQTT broker should retain the message. @@ -56,41 +59,47 @@ const ( ) const ( - DefaultRebroadcastPeriod = 1 * time.Minute - DefaultRebroadcastJitter = 15 * time.Second + DefaultOnconnectDelay = 1 * time.Minute + DefaultOnconnectJitter = 15 * time.Second ) func NewClient(brokerURI string, options ClientOptions) *Client { client := &Client{ - rebroadcastByTopic: map[string]*time.Timer{}, - rebroadcastPeriod: DefaultRebroadcastPeriod, - rebroadcastJitter: DefaultRebroadcastJitter, + payloadByTopic: map[string][]byte{}, + onconnectTimerByTopic: map[string]*time.Timer{}, + + onconnectDelay: DefaultOnconnectDelay, + onconnectJitter: DefaultOnconnectJitter, } - if options.RebroadcastPeriod != 0 { - client.rebroadcastPeriod = options.RebroadcastPeriod + if options.OnconnectDelay != 0 { + client.onconnectDelay = options.OnconnectDelay } - if options.RebroadcastJitter != 0 { - client.rebroadcastJitter = options.RebroadcastJitter + if options.OnconnectJitter != 0 { + client.onconnectJitter = options.OnconnectJitter } - for topic, payload := range options.RebroadcastDefaults { - // TODO: Allow users to set retention? - client.rebroadcastLater(topic, Retain, payload) + for topic, payload := range options.DefaultPayloadByTopic { + client.payloadByTopic[topic] = payload } mqttOpts := mqtt.NewClientOptions() mqttOpts.AddBroker(brokerURI) mqttOpts.SetAutoReconnect(true) - mqttOpts.SetConnectionLostHandler(func(c mqtt.Client, err error) { - if options.DisconnectHandler != nil { - options.DisconnectHandler(client, err) - } - }) mqttOpts.SetOnConnectHandler(func(c mqtt.Client) { + client.stopAllTimers() + client.startAllTimers() + if options.ConnectHandler != nil { options.ConnectHandler(client) } }) + mqttOpts.SetConnectionLostHandler(func(c mqtt.Client, err error) { + client.stopAllTimers() + + if options.DisconnectHandler != nil { + options.DisconnectHandler(client, err) + } + }) client.mqtt = mqtt.NewClient(mqttOpts) return client @@ -107,7 +116,7 @@ func (c *Client) Connect() error { // Subscribe subscribes to a Catbus MQTT topic. func (c *Client) Subscribe(topic string, f MessageHandler) error { return c.mqtt.Subscribe(topic, atLeastOnce, func(_ mqtt.Client, msg mqtt.Message) { - c.rebroadcastLater(msg.Topic(), Retention(msg.Retained()), msg.Payload()) + c.storePayload(msg.Topic(), Retention(msg.Retained()), msg.Payload()) f(c, msg) }).Error() @@ -115,32 +124,68 @@ func (c *Client) Subscribe(topic string, f MessageHandler) error { // Publish publishes to a Catbus MQTT topic. func (c *Client) Publish(topic string, retention Retention, payload []byte) error { - c.rebroadcastLater(topic, retention, payload) + c.storePayload(topic, retention, payload) return c.mqtt.Publish(topic, atLeastOnce, bool(retention), payload).Error() } -func (c *Client) rebroadcastLater(topic string, retention Retention, payload []byte) { - c.rebroadcastMu.Lock() - defer c.rebroadcastMu.Unlock() +func (c *Client) jitteredOnconnectDelay() time.Duration { + jitter := time.Duration(rand.Intn(int(c.onconnectJitter))) + if rand.Intn(2) == 0 { + return c.onconnectDelay + jitter + } + return c.onconnectDelay - jitter +} - if timer := c.rebroadcastByTopic[topic]; timer != nil { - _ = timer.Stop() +func (c *Client) storePayload(topic string, retention Retention, payload []byte) { + c.payloadByTopicMu.Lock() + defer c.payloadByTopicMu.Unlock() + + if _, ok := c.payloadByTopic[topic]; !ok && retention == DontRetain { + // If we don't have a copy, and the sender doesn't want it retained, don't retain it. + return } + c.stopTimer(topic) + if len(payload) == 0 { - // No payload => remove => don't rebroadcast. + delete(c.payloadByTopic, topic) return } + c.payloadByTopic[topic] = payload +} +func (c *Client) stopTimer(topic string) { + c.onconnectTimerByTopicMu.Lock() + defer c.onconnectTimerByTopicMu.Unlock() - c.rebroadcastByTopic[topic] = time.AfterFunc(c.rebroadcastDuration(), func() { - _ = c.Publish(topic, retention, payload) - }) + if timer, ok := c.onconnectTimerByTopic[topic]; ok { + _ = timer.Stop() + } +} +func (c *Client) stopAllTimers() { + c.onconnectTimerByTopicMu.Lock() + defer c.onconnectTimerByTopicMu.Unlock() + + for _, timer := range c.onconnectTimerByTopic { + _ = timer.Stop() + } } -func (c *Client) rebroadcastDuration() time.Duration { - jitter := time.Duration(rand.Intn(int(c.rebroadcastJitter))) - if rand.Intn(1) == 0 { - return c.rebroadcastPeriod + jitter +func (c *Client) startAllTimers() { + c.payloadByTopicMu.Lock() + defer c.payloadByTopicMu.Unlock() + + c.onconnectTimerByTopicMu.Lock() + defer c.onconnectTimerByTopicMu.Unlock() + + for topic := range c.payloadByTopic { + c.onconnectTimerByTopic[topic] = time.AfterFunc(c.jitteredOnconnectDelay(), func() { + c.payloadByTopicMu.Lock() + payload, ok := c.payloadByTopic[topic] + c.payloadByTopicMu.Unlock() + if !ok { + return + } + _ = c.Publish(topic, Retain, payload) + }) } - return c.rebroadcastPeriod - jitter } diff --git a/catbus/catbus_test.go b/catbus/catbus_test.go new file mode 100644 index 0000000..d07367b --- /dev/null +++ b/catbus/catbus_test.go @@ -0,0 +1,206 @@ +// SPDX-FileCopyrightText: 2020 Ethel Morgan +// +// SPDX-License-Identifier: MIT + +package catbus + +import ( + "fmt" + "log" + "reflect" + "testing" + "time" + + mqtt "github.com/eclipse/paho.mqtt.golang" +) + +type ( + message struct { + retention Retention + payload []byte + } +) + +func TestOnConnect(t *testing.T) { + tests := []struct { + payloadByTopic map[string][]byte + subscribe []string + receive map[string]message + + want map[string][]byte + }{ + { + payloadByTopic: map[string][]byte{ + "tv/power": []byte("off"), + }, + want: map[string][]byte{ + "tv/power": []byte("off"), + }, + }, + { + payloadByTopic: map[string][]byte{ + "tv/power": []byte("off"), + }, + subscribe: []string{ + "tv/power", + }, + receive: map[string]message{ + "tv/power": {Retain, []byte("on")}, + }, + want: map[string][]byte{ + "tv/power": []byte("on"), + }, + }, + { + subscribe: []string{ + "tv/power", + }, + receive: map[string]message{ + "tv/power": {Retain, []byte("on")}, + }, + want: map[string][]byte{ + "tv/power": []byte("on"), + }, + }, + { + payloadByTopic: map[string][]byte{ + "tv/power": []byte("off"), + }, + subscribe: []string{ + "tv/power", + }, + receive: map[string]message{ + "tv/power": {DontRetain, []byte("on")}, + }, + want: map[string][]byte{ + "tv/power": []byte("on"), + }, + }, + { + subscribe: []string{ + "tv/power", + }, + receive: map[string]message{ + "tv/power": {DontRetain, []byte("on")}, + }, + want: map[string][]byte{}, + }, + { + payloadByTopic: map[string][]byte{ + "tv/power": []byte("off"), + }, + subscribe: []string{ + "tv/power", + }, + receive: map[string]message{ + "tv/power": {DontRetain, []byte{}}, + }, + want: map[string][]byte{}, + }, + } + + for i, tt := range tests { + fakeMQTT := &fakeMQTT{ + callbackByTopic: map[string]mqtt.MessageHandler{}, + payloadByTopic: map[string][]byte{}, + } + + catbus := &Client{ + mqtt: fakeMQTT, + payloadByTopic: map[string][]byte{}, + onconnectTimerByTopic: map[string]*time.Timer{}, + onconnectDelay: 1 * time.Millisecond, + onconnectJitter: 1, + } + if tt.payloadByTopic != nil { + catbus.payloadByTopic = tt.payloadByTopic + } + + for _, topic := range tt.subscribe { + catbus.Subscribe(topic, func(_ *Client, _ Message) {}) + } + for topic, message := range tt.receive { + fakeMQTT.send(topic, message.retention, message.payload) + } + + catbus.stopAllTimers() + catbus.startAllTimers() + + // TODO: replace with proper channel signaling or sth. + time.Sleep(1 * time.Second) + + got := fakeMQTT.payloadByTopic + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("[%d]: got %v, want %v", i, got, tt.want) + } + } +} + +type ( + fakeMQTT struct { + mqtt.Client + + callbackByTopic map[string]mqtt.MessageHandler + payloadByTopic map[string][]byte + } + + fakeMessage struct { + mqtt.Message + + topic string + retained bool + payload []byte + } + + fakeToken struct{} +) + +func (f *fakeMQTT) Publish(topic string, qos byte, retain bool, payload interface{}) mqtt.Token { + bytes, ok := payload.([]byte) + if !ok { + panic(fmt.Sprintf("expected type []byte, got %v", reflect.TypeOf(payload))) + } + + log.Printf("topic %q payload %s", topic, payload) + f.payloadByTopic[topic] = bytes + return &fakeToken{} +} +func (f *fakeMQTT) Subscribe(topic string, qos byte, callback mqtt.MessageHandler) mqtt.Token { + f.callbackByTopic[topic] = callback + + return &fakeToken{} +} +func (f *fakeMQTT) send(topic string, retention Retention, payload []byte) { + // if retention == Retain { + // f.payloadByTopic[topic] = payload + // } + + if callback, ok := f.callbackByTopic[topic]; ok { + msg := &fakeMessage{ + topic: topic, + retained: bool(retention), + payload: payload, + } + callback(f, msg) + } +} + +func (f *fakeMessage) Topic() string { + return f.topic +} +func (f *fakeMessage) Payload() []byte { + return f.payload +} +func (f *fakeMessage) Retained() bool { + return f.retained +} + +func (_ *fakeToken) Wait() bool { + return false +} +func (_ *fakeToken) WaitTimeout(_ time.Duration) bool { + return false +} +func (_ *fakeToken) Error() error { + return nil +} |