diff options
-rw-r--r-- | catbus.go | 33 | ||||
-rw-r--r-- | catbus_test.go | 111 |
2 files changed, 140 insertions, 4 deletions
@@ -19,6 +19,8 @@ type ( client struct { mqtt mqtt.Client + subscribeEveryMessage bool + payloadByTopicMu sync.Mutex payloadByTopic map[string]string @@ -27,12 +29,20 @@ type ( onconnectDelay time.Duration onconnectJitter time.Duration + + // syncCallbacks makes callbacks synchronous. + // ONLY FOR TESTING. + syncCallbacks bool } ClientOptions struct { DisconnectHandler func(Client, error) ConnectHandler func(Client) + // SubscribeEveryMessage determines if the Subscribe callback will be called for all incoming messages. + // When SubscribeEveryMessage is false, the Subscribe callback will only trigger when the value changes. + SubscribeEveryMessage bool + // Publish previously seen or default values on connecting after OnconnectDelay ± [0,OnconnectJitter). OnconnectDelay time.Duration OnconnectJitter time.Duration @@ -56,6 +66,8 @@ const ( func NewClient(brokerURI string, options ClientOptions) Client { client := &client{ + subscribeEveryMessage: options.SubscribeEveryMessage, + payloadByTopic: map[string]string{}, onconnectTimerByTopic: map[string]*time.Timer{}, @@ -107,10 +119,17 @@ func (c *client) Connect() error { // Subscribe subscribes to a Catbus MQTT topic. func (c *client) Subscribe(topic string, f MessageHandler) error { - return c.mqtt.Subscribe(topic, atLeastOnce, func(_ mqtt.Client, msg mqtt.Message) { - c.storePayload(msg.Topic(), Retention(msg.Retained()), string(msg.Payload())) - - go f(c, messageFromMQTTMessage(msg)) + return c.mqtt.Subscribe(topic, atLeastOnce, func(_ mqtt.Client, raw mqtt.Message) { + msg := messageFromMQTTMessage(raw) + + if c.subscribeEveryMessage || msg.Payload != c.payloadForTopic(topic) { + c.storePayload(msg.Topic, msg.Retention, msg.Payload) + if c.syncCallbacks { + f(c, msg) + } else { + go f(c, msg) + } + } }).Error() } @@ -129,6 +148,12 @@ func (c *client) jitteredOnconnectDelay() time.Duration { return c.onconnectDelay - jitter } +func (c *client) payloadForTopic(topic string) string { + c.payloadByTopicMu.Lock() + defer c.payloadByTopicMu.Unlock() + return c.payloadByTopic[topic] +} + func (c *client) storePayload(topic string, retention Retention, payload string) { c.payloadByTopicMu.Lock() defer c.payloadByTopicMu.Unlock() diff --git a/catbus_test.go b/catbus_test.go index b345c48..ecabcdd 100644 --- a/catbus_test.go +++ b/catbus_test.go @@ -21,6 +21,117 @@ type ( } ) +func TestSubscribe(t *testing.T) { + tests := []struct { + messages []Message + subscribeEveryMessage bool + want []Message + }{ + { + messages: []Message{ + { + Topic: "home/tv/power", + Retention: Retain, + Payload: "on", + }, + { + Topic: "home/tv/power", + Retention: Retain, + Payload: "on", + }, + }, + want: []Message{ + { + Topic: "home/tv/power", + Retention: Retain, + Payload: "on", + }, + }, + }, + { + messages: []Message{ + { + Topic: "home/tv/power", + Retention: Retain, + Payload: "on", + }, + { + Topic: "home/tv/power", + Retention: Retain, + Payload: "off", + }, + }, + want: []Message{ + { + Topic: "home/tv/power", + Retention: Retain, + Payload: "on", + }, + { + Topic: "home/tv/power", + Retention: Retain, + Payload: "off", + }, + }, + }, + { + messages: []Message{ + { + Topic: "home/tv/power", + Retention: Retain, + Payload: "on", + }, + { + Topic: "home/tv/power", + Retention: Retain, + Payload: "on", + }, + }, + subscribeEveryMessage: true, + want: []Message{ + { + Topic: "home/tv/power", + Retention: Retain, + Payload: "on", + }, + { + Topic: "home/tv/power", + Retention: Retain, + Payload: "on", + }, + }, + }, + } + + for i, tt := range tests { + fakeMQTT := &fakeMQTT{ + callbackByTopic: map[string]mqtt.MessageHandler{}, + payloadByTopic: map[string]string{}, + } + + catbus := &client{ + mqtt: fakeMQTT, + payloadByTopic: map[string]string{}, + subscribeEveryMessage: tt.subscribeEveryMessage, + + syncCallbacks: true, + } + + var got []Message + catbus.Subscribe("home/tv/power", func(_ Client, msg Message) { + got = append(got, msg) + }) + + for _, msg := range tt.messages { + fakeMQTT.send(msg.Topic, msg.Retention, msg.Payload) + } + + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("[%d]: got %v, want %v", i, got, tt.want) + } + } +} + func TestOnConnect(t *testing.T) { tests := []struct { payloadByTopic map[string]string |