aboutsummaryrefslogtreecommitdiff
path: root/catbus/catbus.go
diff options
context:
space:
mode:
Diffstat (limited to 'catbus/catbus.go')
-rw-r--r--catbus/catbus.go131
1 files changed, 88 insertions, 43 deletions
diff --git a/catbus/catbus.go b/catbus/catbus.go
index 2538a2f..bb674c6 100644
--- a/catbus/catbus.go
+++ b/catbus/catbus.go
@@ -20,24 +20,27 @@ type (
Client struct {
mqtt mqtt.Client
- rebroadcastMu sync.Mutex
- rebroadcastByTopic map[string]*time.Timer
- rebroadcastPeriod time.Duration
- rebroadcastJitter time.Duration
- }
+ payloadByTopicMu sync.Mutex
+ payloadByTopic map[string][]byte
+
+ onconnectTimerByTopicMu sync.Mutex
+ onconnectTimerByTopic map[string]*time.Timer
+ onconnectDelay time.Duration
+ onconnectJitter time.Duration
+ }
ClientOptions struct {
DisconnectHandler func(*Client, error)
ConnectHandler func(*Client)
- // Rebroadcast previously seen values every RebroadcastPeriod ± [0,RebroadcastJitter).
- RebroadcastPeriod time.Duration
- RebroadcastJitter time.Duration
+ // Publish previously seen or default values on connecting after OnconnectDelay ± [0,OnconnectJitter).
+ OnconnectDelay time.Duration
+ OnconnectJitter time.Duration
- // RebroadcastDefaults are optional values to seed rebroadcasting if no prior values are seen.
+ // DefaultPayloadByTopic are optional values to publish on connect if no prior values are seen.
// E.g. unless we've been told otherwise, assume a device is off.
- RebroadcastDefaults map[string][]byte
+ DefaultPayloadByTopic map[string][]byte
}
// Retention is whether or not the MQTT broker should retain the message.
@@ -56,41 +59,47 @@ const (
)
const (
- DefaultRebroadcastPeriod = 1 * time.Minute
- DefaultRebroadcastJitter = 15 * time.Second
+ DefaultOnconnectDelay = 1 * time.Minute
+ DefaultOnconnectJitter = 15 * time.Second
)
func NewClient(brokerURI string, options ClientOptions) *Client {
client := &Client{
- rebroadcastByTopic: map[string]*time.Timer{},
- rebroadcastPeriod: DefaultRebroadcastPeriod,
- rebroadcastJitter: DefaultRebroadcastJitter,
+ payloadByTopic: map[string][]byte{},
+ onconnectTimerByTopic: map[string]*time.Timer{},
+
+ onconnectDelay: DefaultOnconnectDelay,
+ onconnectJitter: DefaultOnconnectJitter,
}
- if options.RebroadcastPeriod != 0 {
- client.rebroadcastPeriod = options.RebroadcastPeriod
+ if options.OnconnectDelay != 0 {
+ client.onconnectDelay = options.OnconnectDelay
}
- if options.RebroadcastJitter != 0 {
- client.rebroadcastJitter = options.RebroadcastJitter
+ if options.OnconnectJitter != 0 {
+ client.onconnectJitter = options.OnconnectJitter
}
- for topic, payload := range options.RebroadcastDefaults {
- // TODO: Allow users to set retention?
- client.rebroadcastLater(topic, Retain, payload)
+ for topic, payload := range options.DefaultPayloadByTopic {
+ client.payloadByTopic[topic] = payload
}
mqttOpts := mqtt.NewClientOptions()
mqttOpts.AddBroker(brokerURI)
mqttOpts.SetAutoReconnect(true)
- mqttOpts.SetConnectionLostHandler(func(c mqtt.Client, err error) {
- if options.DisconnectHandler != nil {
- options.DisconnectHandler(client, err)
- }
- })
mqttOpts.SetOnConnectHandler(func(c mqtt.Client) {
+ client.stopAllTimers()
+ client.startAllTimers()
+
if options.ConnectHandler != nil {
options.ConnectHandler(client)
}
})
+ mqttOpts.SetConnectionLostHandler(func(c mqtt.Client, err error) {
+ client.stopAllTimers()
+
+ if options.DisconnectHandler != nil {
+ options.DisconnectHandler(client, err)
+ }
+ })
client.mqtt = mqtt.NewClient(mqttOpts)
return client
@@ -107,7 +116,7 @@ func (c *Client) Connect() error {
// Subscribe subscribes to a Catbus MQTT topic.
func (c *Client) Subscribe(topic string, f MessageHandler) error {
return c.mqtt.Subscribe(topic, atLeastOnce, func(_ mqtt.Client, msg mqtt.Message) {
- c.rebroadcastLater(msg.Topic(), Retention(msg.Retained()), msg.Payload())
+ c.storePayload(msg.Topic(), Retention(msg.Retained()), msg.Payload())
f(c, msg)
}).Error()
@@ -115,32 +124,68 @@ func (c *Client) Subscribe(topic string, f MessageHandler) error {
// Publish publishes to a Catbus MQTT topic.
func (c *Client) Publish(topic string, retention Retention, payload []byte) error {
- c.rebroadcastLater(topic, retention, payload)
+ c.storePayload(topic, retention, payload)
return c.mqtt.Publish(topic, atLeastOnce, bool(retention), payload).Error()
}
-func (c *Client) rebroadcastLater(topic string, retention Retention, payload []byte) {
- c.rebroadcastMu.Lock()
- defer c.rebroadcastMu.Unlock()
+func (c *Client) jitteredOnconnectDelay() time.Duration {
+ jitter := time.Duration(rand.Intn(int(c.onconnectJitter)))
+ if rand.Intn(2) == 0 {
+ return c.onconnectDelay + jitter
+ }
+ return c.onconnectDelay - jitter
+}
- if timer := c.rebroadcastByTopic[topic]; timer != nil {
- _ = timer.Stop()
+func (c *Client) storePayload(topic string, retention Retention, payload []byte) {
+ c.payloadByTopicMu.Lock()
+ defer c.payloadByTopicMu.Unlock()
+
+ if _, ok := c.payloadByTopic[topic]; !ok && retention == DontRetain {
+ // If we don't have a copy, and the sender doesn't want it retained, don't retain it.
+ return
}
+ c.stopTimer(topic)
+
if len(payload) == 0 {
- // No payload => remove => don't rebroadcast.
+ delete(c.payloadByTopic, topic)
return
}
+ c.payloadByTopic[topic] = payload
+}
+func (c *Client) stopTimer(topic string) {
+ c.onconnectTimerByTopicMu.Lock()
+ defer c.onconnectTimerByTopicMu.Unlock()
- c.rebroadcastByTopic[topic] = time.AfterFunc(c.rebroadcastDuration(), func() {
- _ = c.Publish(topic, retention, payload)
- })
+ if timer, ok := c.onconnectTimerByTopic[topic]; ok {
+ _ = timer.Stop()
+ }
+}
+func (c *Client) stopAllTimers() {
+ c.onconnectTimerByTopicMu.Lock()
+ defer c.onconnectTimerByTopicMu.Unlock()
+
+ for _, timer := range c.onconnectTimerByTopic {
+ _ = timer.Stop()
+ }
}
-func (c *Client) rebroadcastDuration() time.Duration {
- jitter := time.Duration(rand.Intn(int(c.rebroadcastJitter)))
- if rand.Intn(1) == 0 {
- return c.rebroadcastPeriod + jitter
+func (c *Client) startAllTimers() {
+ c.payloadByTopicMu.Lock()
+ defer c.payloadByTopicMu.Unlock()
+
+ c.onconnectTimerByTopicMu.Lock()
+ defer c.onconnectTimerByTopicMu.Unlock()
+
+ for topic := range c.payloadByTopic {
+ c.onconnectTimerByTopic[topic] = time.AfterFunc(c.jitteredOnconnectDelay(), func() {
+ c.payloadByTopicMu.Lock()
+ payload, ok := c.payloadByTopic[topic]
+ c.payloadByTopicMu.Unlock()
+ if !ok {
+ return
+ }
+ _ = c.Publish(topic, Retain, payload)
+ })
}
- return c.rebroadcastPeriod - jitter
}