aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--catbus.go33
-rw-r--r--catbus_test.go111
2 files changed, 140 insertions, 4 deletions
diff --git a/catbus.go b/catbus.go
index 66cab2e..c8cf033 100644
--- a/catbus.go
+++ b/catbus.go
@@ -19,6 +19,8 @@ type (
client struct {
mqtt mqtt.Client
+ subscribeEveryMessage bool
+
payloadByTopicMu sync.Mutex
payloadByTopic map[string]string
@@ -27,12 +29,20 @@ type (
onconnectDelay time.Duration
onconnectJitter time.Duration
+
+ // syncCallbacks makes callbacks synchronous.
+ // ONLY FOR TESTING.
+ syncCallbacks bool
}
ClientOptions struct {
DisconnectHandler func(Client, error)
ConnectHandler func(Client)
+ // SubscribeEveryMessage determines if the Subscribe callback will be called for all incoming messages.
+ // When SubscribeEveryMessage is false, the Subscribe callback will only trigger when the value changes.
+ SubscribeEveryMessage bool
+
// Publish previously seen or default values on connecting after OnconnectDelay ± [0,OnconnectJitter).
OnconnectDelay time.Duration
OnconnectJitter time.Duration
@@ -56,6 +66,8 @@ const (
func NewClient(brokerURI string, options ClientOptions) Client {
client := &client{
+ subscribeEveryMessage: options.SubscribeEveryMessage,
+
payloadByTopic: map[string]string{},
onconnectTimerByTopic: map[string]*time.Timer{},
@@ -107,10 +119,17 @@ func (c *client) Connect() error {
// Subscribe subscribes to a Catbus MQTT topic.
func (c *client) Subscribe(topic string, f MessageHandler) error {
- return c.mqtt.Subscribe(topic, atLeastOnce, func(_ mqtt.Client, msg mqtt.Message) {
- c.storePayload(msg.Topic(), Retention(msg.Retained()), string(msg.Payload()))
-
- go f(c, messageFromMQTTMessage(msg))
+ return c.mqtt.Subscribe(topic, atLeastOnce, func(_ mqtt.Client, raw mqtt.Message) {
+ msg := messageFromMQTTMessage(raw)
+
+ if c.subscribeEveryMessage || msg.Payload != c.payloadForTopic(topic) {
+ c.storePayload(msg.Topic, msg.Retention, msg.Payload)
+ if c.syncCallbacks {
+ f(c, msg)
+ } else {
+ go f(c, msg)
+ }
+ }
}).Error()
}
@@ -129,6 +148,12 @@ func (c *client) jitteredOnconnectDelay() time.Duration {
return c.onconnectDelay - jitter
}
+func (c *client) payloadForTopic(topic string) string {
+ c.payloadByTopicMu.Lock()
+ defer c.payloadByTopicMu.Unlock()
+ return c.payloadByTopic[topic]
+}
+
func (c *client) storePayload(topic string, retention Retention, payload string) {
c.payloadByTopicMu.Lock()
defer c.payloadByTopicMu.Unlock()
diff --git a/catbus_test.go b/catbus_test.go
index b345c48..ecabcdd 100644
--- a/catbus_test.go
+++ b/catbus_test.go
@@ -21,6 +21,117 @@ type (
}
)
+func TestSubscribe(t *testing.T) {
+ tests := []struct {
+ messages []Message
+ subscribeEveryMessage bool
+ want []Message
+ }{
+ {
+ messages: []Message{
+ {
+ Topic: "home/tv/power",
+ Retention: Retain,
+ Payload: "on",
+ },
+ {
+ Topic: "home/tv/power",
+ Retention: Retain,
+ Payload: "on",
+ },
+ },
+ want: []Message{
+ {
+ Topic: "home/tv/power",
+ Retention: Retain,
+ Payload: "on",
+ },
+ },
+ },
+ {
+ messages: []Message{
+ {
+ Topic: "home/tv/power",
+ Retention: Retain,
+ Payload: "on",
+ },
+ {
+ Topic: "home/tv/power",
+ Retention: Retain,
+ Payload: "off",
+ },
+ },
+ want: []Message{
+ {
+ Topic: "home/tv/power",
+ Retention: Retain,
+ Payload: "on",
+ },
+ {
+ Topic: "home/tv/power",
+ Retention: Retain,
+ Payload: "off",
+ },
+ },
+ },
+ {
+ messages: []Message{
+ {
+ Topic: "home/tv/power",
+ Retention: Retain,
+ Payload: "on",
+ },
+ {
+ Topic: "home/tv/power",
+ Retention: Retain,
+ Payload: "on",
+ },
+ },
+ subscribeEveryMessage: true,
+ want: []Message{
+ {
+ Topic: "home/tv/power",
+ Retention: Retain,
+ Payload: "on",
+ },
+ {
+ Topic: "home/tv/power",
+ Retention: Retain,
+ Payload: "on",
+ },
+ },
+ },
+ }
+
+ for i, tt := range tests {
+ fakeMQTT := &fakeMQTT{
+ callbackByTopic: map[string]mqtt.MessageHandler{},
+ payloadByTopic: map[string]string{},
+ }
+
+ catbus := &client{
+ mqtt: fakeMQTT,
+ payloadByTopic: map[string]string{},
+ subscribeEveryMessage: tt.subscribeEveryMessage,
+
+ syncCallbacks: true,
+ }
+
+ var got []Message
+ catbus.Subscribe("home/tv/power", func(_ Client, msg Message) {
+ got = append(got, msg)
+ })
+
+ for _, msg := range tt.messages {
+ fakeMQTT.send(msg.Topic, msg.Retention, msg.Payload)
+ }
+
+ if !reflect.DeepEqual(got, tt.want) {
+ t.Errorf("[%d]: got %v, want %v", i, got, tt.want)
+ }
+ }
+}
+
func TestOnConnect(t *testing.T) {
tests := []struct {
payloadByTopic map[string]string