From 61806a5fa2166890fbf77a00a6c0b784b4c32c4b Mon Sep 17 00:00:00 2001 From: Empire Phoenix Date: Mon, 27 Apr 2026 09:39:29 +0200 Subject: [PATCH] Add `mcutie` MQTT client implementation and improve library structure - Integrated `mcutie` library as a core MQTT client for device communication. - Added support for Home Assistant entities (binary sensor, button) via MQTT. - Implemented buffer management, async operations, and packet encoding/decoding. - Introduced structured error handling and device registration features. - Updated `Cargo.toml` with new dependencies and enabled feature flags for `serde` and `log`. - Enhanced logging macros with configurable options (`defmt` or `log`). - Organized codebase into modules (buffer, components, IO, publish, etc.) for better maintainability. --- .../rust/src/mcutie_3_0_0/Cargo.toml | 34 ++ .../MainBoard/rust/src/mcutie_3_0_0/buffer.rs | 124 +++++ .../MainBoard/rust/src/mcutie_3_0_0/fmt.rs | 80 +++ .../homeassistant/binary_sensor.rs | 120 +++++ .../src/mcutie_3_0_0/homeassistant/button.rs | 40 ++ .../src/mcutie_3_0_0/homeassistant/light.rs | 384 ++++++++++++++ .../src/mcutie_3_0_0/homeassistant/mod.rs | 295 +++++++++++ .../src/mcutie_3_0_0/homeassistant/sensor.rs | 103 ++++ .../src/mcutie_3_0_0/homeassistant/ser.rs | 333 ++++++++++++ .../MainBoard/rust/src/mcutie_3_0_0/io.rs | 483 ++++++++++++++++++ .../MainBoard/rust/src/mcutie_3_0_0/lib.rs | 227 ++++++++ .../MainBoard/rust/src/mcutie_3_0_0/pipe.rs | 267 ++++++++++ .../rust/src/mcutie_3_0_0/publish.rs | 173 +++++++ .../MainBoard/rust/src/mcutie_3_0_0/topic.rs | 284 ++++++++++ 14 files changed, 2947 insertions(+) create mode 100644 Software/MainBoard/rust/src/mcutie_3_0_0/Cargo.toml create mode 100644 Software/MainBoard/rust/src/mcutie_3_0_0/buffer.rs create mode 100644 Software/MainBoard/rust/src/mcutie_3_0_0/fmt.rs create mode 100644 Software/MainBoard/rust/src/mcutie_3_0_0/homeassistant/binary_sensor.rs create mode 100644 Software/MainBoard/rust/src/mcutie_3_0_0/homeassistant/button.rs create mode 100644 Software/MainBoard/rust/src/mcutie_3_0_0/homeassistant/light.rs create mode 100644 Software/MainBoard/rust/src/mcutie_3_0_0/homeassistant/mod.rs create mode 100644 Software/MainBoard/rust/src/mcutie_3_0_0/homeassistant/sensor.rs create mode 100644 Software/MainBoard/rust/src/mcutie_3_0_0/homeassistant/ser.rs create mode 100644 Software/MainBoard/rust/src/mcutie_3_0_0/io.rs create mode 100644 Software/MainBoard/rust/src/mcutie_3_0_0/lib.rs create mode 100644 Software/MainBoard/rust/src/mcutie_3_0_0/pipe.rs create mode 100644 Software/MainBoard/rust/src/mcutie_3_0_0/publish.rs create mode 100644 Software/MainBoard/rust/src/mcutie_3_0_0/topic.rs diff --git a/Software/MainBoard/rust/src/mcutie_3_0_0/Cargo.toml b/Software/MainBoard/rust/src/mcutie_3_0_0/Cargo.toml new file mode 100644 index 0000000..eb48eb1 --- /dev/null +++ b/Software/MainBoard/rust/src/mcutie_3_0_0/Cargo.toml @@ -0,0 +1,34 @@ +[package] +name = "mcutie" +version = "3.0.0" +edition = "2021" + +[lib] +path = "lib.rs" + +[features] +default = [] +homeassistant = [] +serde = ["dep:serde", "heapless/serde"] +defmt = [] +log = ["dep:log"] + +[dependencies] +embassy-net = { version = "0.8.0", default-features = false, features = ["tcp", "dns", "proto-ipv4", "proto-ipv6", "medium-ethernet"] } +embassy-sync = { version = "0.8.0", default-features = false } +embassy-time = { version = "0.5.1", default-features = false } +embassy-futures = { version = "0.1.2", default-features = false } +embedded-io = { version = "0.7.1", default-features = false } +embedded-io-async = { version = "0.7.0", default-features = false } +heapless = { version = "0.7.17", default-features = false } +mqttrs = { version = "0.4.1", default-features = false } +once_cell = { version = "1.21.3", default-features = false, features = ["critical-section"] } +pin-project = { version = "1.1.10", default-features = false } +hex = { version = "0.4.3", default-features = false } +serde = { version = "1.0.228", default-features = false, features = ["derive"], optional = true } +log = { version = "0.4.28", default-features = false, optional = true } + +[dev-dependencies] +futures-executor = "0.3.31" +futures-timer = "3.0.3" +futures-util = "0.3.31" diff --git a/Software/MainBoard/rust/src/mcutie_3_0_0/buffer.rs b/Software/MainBoard/rust/src/mcutie_3_0_0/buffer.rs new file mode 100644 index 0000000..2397f80 --- /dev/null +++ b/Software/MainBoard/rust/src/mcutie_3_0_0/buffer.rs @@ -0,0 +1,124 @@ +use core::{cmp, fmt, ops::Deref}; + +use embedded_io::{SliceWriteError, Write}; +use mqttrs::{encode_slice, Packet}; + +use crate::Error; + +/// A stack allocated buffer that can be written to and then read back from. +/// Dereferencing as a [`u8`] slice allows access to previously written data. +/// +/// Can be written to with [`write!`] and supports [`embedded_io::Write`] and +/// [`embedded_io_async::Write`]. +pub struct Buffer { + bytes: [u8; N], + cursor: usize, +} + +impl Default for Buffer { + fn default() -> Self { + Self::new() + } +} + +impl Buffer { + /// Creates a new buffer. + pub(crate) const fn new() -> Self { + Self { + bytes: [0; N], + cursor: 0, + } + } + + /// Creates a new buffer and writes the given data into it. + pub(crate) fn from(buf: &[u8]) -> Result { + let mut buffer = Self::new(); + match buffer.write_all(buf) { + Ok(()) => Ok(buffer), + Err(_) => Err(Error::TooLarge), + } + } + + pub(crate) fn encode_packet(&mut self, packet: &Packet<'_>) -> Result<(), mqttrs::Error> { + let len = encode_slice(packet, &mut self.bytes[self.cursor..])?; + self.cursor += len; + + Ok(()) + } + + #[cfg(feature = "serde")] + /// Serializes a value into this buffer using JSON. + pub(crate) fn serialize_json( + &mut self, + value: &T, + ) -> Result<(), serde_json_core::ser::Error> { + let len = serde_json_core::to_slice(value, &mut self.bytes[self.cursor..])?; + self.cursor += len; + + Ok(()) + } + + #[cfg(feature = "serde")] + /// Deserializes this buffer using JSON into the given type. + pub fn deserialize_json<'a, T: serde::Deserialize<'a>>( + &'a self, + ) -> Result { + let (result, _) = serde_json_core::from_slice(self)?; + + Ok(result) + } + + /// The number of bytes available for writing into this buffer. + pub fn available(&self) -> usize { + N - self.cursor + } +} + +impl Deref for Buffer { + type Target = [u8]; + + fn deref(&self) -> &Self::Target { + &self.bytes[0..self.cursor] + } +} + +impl fmt::Write for Buffer { + fn write_str(&mut self, s: &str) -> fmt::Result { + self.write_all(s.as_bytes()).map_err(|_| fmt::Error) + } +} + +impl embedded_io::ErrorType for Buffer { + type Error = SliceWriteError; +} + +impl embedded_io::Write for Buffer { + fn write(&mut self, buf: &[u8]) -> Result { + if buf.is_empty() { + return Ok(0); + } + + let writable = cmp::min(self.available(), buf.len()); + if writable == 0 { + Err(SliceWriteError::Full) + } else { + self.bytes[self.cursor..self.cursor + writable].copy_from_slice(buf); + self.cursor += writable; + Ok(writable) + } + } + + fn flush(&mut self) -> Result<(), Self::Error> { + Ok(()) + } +} + +impl embedded_io_async::Write for Buffer { + async fn write(&mut self, buf: &[u8]) -> Result { + ::write(self, buf) + } + + async fn flush(&mut self) -> Result<(), Self::Error> { + Ok(()) + } +} diff --git a/Software/MainBoard/rust/src/mcutie_3_0_0/fmt.rs b/Software/MainBoard/rust/src/mcutie_3_0_0/fmt.rs new file mode 100644 index 0000000..b678fbf --- /dev/null +++ b/Software/MainBoard/rust/src/mcutie_3_0_0/fmt.rs @@ -0,0 +1,80 @@ +#![macro_use] + +#[cfg(all(feature = "defmt", feature = "log"))] +compile_error!("The `defmt` and `log` features cannot both be enabled at the same time."); + +#[cfg(not(feature = "defmt"))] +use core::fmt; + +#[cfg(feature = "defmt")] +pub(crate) use ::defmt::Debug2Format; + +#[cfg(not(feature = "defmt"))] +pub(crate) struct Debug2Format(pub(crate) D); + +#[cfg(feature = "log")] +impl fmt::Debug for Debug2Format { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.0.fmt(f) + } +} + +#[collapse_debuginfo(yes)] +macro_rules! trace { + ($s:literal $(, $x:expr)* $(,)?) => { + #[cfg(feature = "defmt")] + ::defmt::trace!($s $(, $x)*); + #[cfg(feature = "log")] + ::log::trace!($s $(, $x)*); + #[cfg(not(any(feature="defmt", feature="log")))] + let _ = ($( & $x ),*); + }; +} + +#[collapse_debuginfo(yes)] +macro_rules! debug { + ($s:literal $(, $x:expr)* $(,)?) => { + #[cfg(feature = "defmt")] + ::defmt::debug!($s $(, $x)*); + #[cfg(feature = "log")] + ::log::debug!($s $(, $x)*); + #[cfg(not(any(feature="defmt", feature="log")))] + let _ = ($( & $x ),*); + }; +} + +#[collapse_debuginfo(yes)] +macro_rules! info { + ($s:literal $(, $x:expr)* $(,)?) => { + #[cfg(feature = "defmt")] + ::defmt::info!($s $(, $x)*); + #[cfg(feature = "log")] + ::log::info!($s $(, $x)*); + #[cfg(not(any(feature="defmt", feature="log")))] + let _ = ($( & $x ),*); + }; +} + +#[collapse_debuginfo(yes)] +macro_rules! warn { + ($s:literal $(, $x:expr)* $(,)?) => { + #[cfg(feature = "defmt")] + ::defmt::warn!($s $(, $x)*); + #[cfg(feature = "log")] + ::log::warn!($s $(, $x)*); + #[cfg(not(any(feature="defmt", feature="log")))] + let _ = ($( & $x ),*); + }; +} + +#[collapse_debuginfo(yes)] +macro_rules! error { + ($s:literal $(, $x:expr)* $(,)?) => { + #[cfg(feature = "defmt")] + ::defmt::error!($s $(, $x)*); + #[cfg(feature = "log")] + ::log::error!($s $(, $x)*); + #[cfg(not(any(feature="defmt", feature="log")))] + let _ = ($( & $x ),*); + }; +} diff --git a/Software/MainBoard/rust/src/mcutie_3_0_0/homeassistant/binary_sensor.rs b/Software/MainBoard/rust/src/mcutie_3_0_0/homeassistant/binary_sensor.rs new file mode 100644 index 0000000..b62d2fd --- /dev/null +++ b/Software/MainBoard/rust/src/mcutie_3_0_0/homeassistant/binary_sensor.rs @@ -0,0 +1,120 @@ +//! Tools for publishing a [Home Assistant binary sensor](https://www.home-assistant.io/integrations/binary_sensor.mqtt/). +use core::ops::Deref; + +use serde::{Deserialize, Serialize}; + +use crate::{homeassistant::Component, Error, Publishable, Topic}; + +/// The state of the sensor. Can be easily converted to or from a [`bool`]. +#[derive(Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(from = "&str", into = "&'static str")] +#[allow(missing_docs)] +pub enum BinarySensorState { + On, + Off, +} + +impl From for &'static str { + fn from(state: BinarySensorState) -> Self { + match state { + BinarySensorState::On => "ON", + BinarySensorState::Off => "OFF", + } + } +} + +impl<'a> From<&'a str> for BinarySensorState { + fn from(st: &'a str) -> Self { + if st == "ON" { + Self::On + } else { + Self::Off + } + } +} + +impl From for BinarySensorState { + fn from(val: bool) -> Self { + if val { + BinarySensorState::On + } else { + BinarySensorState::Off + } + } +} + +impl From for bool { + fn from(val: BinarySensorState) -> Self { + match val { + BinarySensorState::On => true, + BinarySensorState::Off => true, + } + } +} + +impl AsRef<[u8]> for BinarySensorState { + fn as_ref(&self) -> &'static [u8] { + match self { + Self::On => "ON".as_bytes(), + Self::Off => "OFF".as_bytes(), + } + } +} + +/// The type of sensor. +#[derive(Serialize)] +#[serde(rename_all = "snake_case")] +#[allow(missing_docs)] +pub enum BinarySensorClass { + Battery, + BatteryCharging, + CarbonMonoxide, + Cold, + Connectivity, + Door, + GarageDoor, + Gas, + Heat, + Light, + Lock, + Moisture, + Motion, + Moving, + Occupancy, + Opening, + Plug, + Power, + Presence, + Problem, + Running, + Safety, + Smoke, + Sound, + Tamper, + Update, + Vibration, + Window, +} + +/// A binary sensor that can publish a [`BinarySensorState`] status. +#[derive(Serialize)] +pub struct BinarySensor { + /// The type of sensor + pub device_class: Option, +} + +impl Component for BinarySensor { + type State = BinarySensorState; + + fn platform() -> &'static str { + "binary_sensor" + } + + async fn publish_state>( + &self, + topic: &Topic, + state: Self::State, + ) -> Result<(), Error> { + topic.with_bytes(state).publish().await + } +} diff --git a/Software/MainBoard/rust/src/mcutie_3_0_0/homeassistant/button.rs b/Software/MainBoard/rust/src/mcutie_3_0_0/homeassistant/button.rs new file mode 100644 index 0000000..b19a66f --- /dev/null +++ b/Software/MainBoard/rust/src/mcutie_3_0_0/homeassistant/button.rs @@ -0,0 +1,40 @@ +//! Tools for publishing a [Home Assistant button](https://www.home-assistant.io/integrations/button.mqtt/). +use core::ops::Deref; + +use serde::Serialize; + +use crate::{homeassistant::Component, Error, Topic}; + +/// The type of button. +#[derive(Serialize)] +#[serde(rename_all = "snake_case")] +#[allow(missing_docs)] +pub enum ButtonClass { + Identify, + Restart, + Update, +} + +/// A button that can be pressed. +#[derive(Serialize)] +pub struct Button { + /// The type of button. + pub device_class: Option, +} + +impl Component for Button { + type State = (); + + fn platform() -> &'static str { + "button" + } + + async fn publish_state>( + &self, + _topic: &Topic, + _state: Self::State, + ) -> Result<(), Error> { + // Buttons don't have a state + Err(Error::Invalid) + } +} diff --git a/Software/MainBoard/rust/src/mcutie_3_0_0/homeassistant/light.rs b/Software/MainBoard/rust/src/mcutie_3_0_0/homeassistant/light.rs new file mode 100644 index 0000000..1b85d11 --- /dev/null +++ b/Software/MainBoard/rust/src/mcutie_3_0_0/homeassistant/light.rs @@ -0,0 +1,384 @@ +//! Tools for publishing a [Home Assistant light](https://www.home-assistant.io/integrations/light.mqtt/). +use core::{ops::Deref, str}; + +use serde::{ser::SerializeStruct, Deserialize, Serialize, Serializer}; + +use crate::{ + fmt::Debug2Format, + homeassistant::{binary_sensor::BinarySensorState, ser::List, Component}, + Error, Payload, Publishable, Topic, +}; + +#[derive(Serialize)] +#[serde(rename_all = "lowercase")] +#[allow(missing_docs)] +pub enum SupportedColorMode { + OnOff, + Brightness, + #[serde(rename = "color_temp")] + ColorTemp, + Hs, + Xy, + Rgb, + Rgbw, + Rgbww, + White, +} + +#[derive(Serialize, Deserialize, Default)] +struct SerializedColor { + #[serde(default, skip_serializing_if = "Option::is_none")] + h: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + s: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + x: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + y: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + r: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + g: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + b: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + w: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + c: Option, +} + +#[derive(Deserialize)] +struct LedPayload<'a> { + state: BinarySensorState, + #[serde(default)] + brightness: Option, + #[serde(default)] + color_temp: Option, + #[serde(default)] + color: Option, + #[serde(default)] + effect: Option<&'a str>, +} + +/// The color of the light in various forms. +#[derive(Serialize)] +#[serde(rename_all = "lowercase", tag = "color_mode", content = "color")] +#[allow(missing_docs)] +pub enum Color { + None, + Brightness(u8), + ColorTemp(u32), + Hs { + #[serde(rename = "h")] + hue: f32, + #[serde(rename = "s")] + saturation: f32, + }, + Xy { + x: f32, + y: f32, + }, + Rgb { + #[serde(rename = "r")] + red: u8, + #[serde(rename = "g")] + green: u8, + #[serde(rename = "b")] + blue: u8, + }, + Rgbw { + #[serde(rename = "r")] + red: u8, + #[serde(rename = "g")] + green: u8, + #[serde(rename = "b")] + blue: u8, + #[serde(rename = "w")] + white: u8, + }, + Rgbww { + #[serde(rename = "r")] + red: u8, + #[serde(rename = "g")] + green: u8, + #[serde(rename = "b")] + blue: u8, + #[serde(rename = "c")] + cool_white: u8, + #[serde(rename = "w")] + warm_white: u8, + }, +} + +/// The state of the light. This can be sent to the broker and received as a +/// command from Home Assistant. +pub struct LightState<'a> { + /// Whether the light is on or off. + pub state: BinarySensorState, + /// The color of the light. + pub color: Color, + /// Any effect that is applied. + pub effect: Option<&'a str>, +} + +impl<'a> LightState<'a> { + /// Parses the state from a command payload. + pub fn from_payload(payload: &'a Payload) -> Result { + let parsed: LedPayload<'a> = match payload.deserialize_json() { + Ok(p) => p, + Err(e) => { + warn!("Failed to deserialize packet: {:?}", Debug2Format(&e)); + if let Ok(s) = str::from_utf8(payload) { + trace!("{}", s); + } + return Err(Error::PacketError); + } + }; + + let color = if let Some(color) = parsed.color { + if let Some(x) = color.x { + Color::Xy { + x, + y: color.y.unwrap_or_default(), + } + } else if let Some(h) = color.h { + Color::Hs { + hue: h, + saturation: color.s.unwrap_or_default(), + } + } else if let Some(c) = color.c { + Color::Rgbww { + red: color.r.unwrap_or_default(), + green: color.g.unwrap_or_default(), + blue: color.b.unwrap_or_default(), + cool_white: c, + warm_white: color.w.unwrap_or_default(), + } + } else if let Some(w) = color.w { + Color::Rgbw { + red: color.r.unwrap_or_default(), + green: color.g.unwrap_or_default(), + blue: color.b.unwrap_or_default(), + white: w, + } + } else { + Color::Rgb { + red: color.r.unwrap_or_default(), + green: color.g.unwrap_or_default(), + blue: color.b.unwrap_or_default(), + } + } + } else if let Some(color_temp) = parsed.color_temp { + Color::ColorTemp(color_temp) + } else if let Some(brightness) = parsed.brightness { + Color::Brightness(brightness) + } else { + Color::None + }; + + Ok(LightState { + state: parsed.state, + color, + effect: parsed.effect, + }) + } +} + +impl Serialize for LightState<'_> { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + let mut len = 1; + + if self.effect.is_some() { + len += 1; + } + + match self.color { + Color::None => {} + Color::Brightness(_) | Color::ColorTemp(_) => len += 1, + _ => len += 2, + } + + let mut serializer = serializer.serialize_struct("LightState", len)?; + + serializer.serialize_field("state", &self.state)?; + + if let Some(effect) = self.effect { + serializer.serialize_field("effect", effect)?; + } else { + serializer.skip_field("effect")?; + } + + match self.color { + Color::None => { + serializer.skip_field("brightness")?; + serializer.skip_field("color_temp")?; + serializer.skip_field("color")?; + } + Color::Brightness(b) => { + serializer.skip_field("color_temp")?; + serializer.skip_field("color")?; + + serializer.serialize_field("brightness", &b)? + } + Color::ColorTemp(c) => { + serializer.skip_field("brightness")?; + serializer.skip_field("color")?; + + serializer.serialize_field("color_temp", &c)? + } + Color::Hs { hue, saturation } => { + serializer.skip_field("brightness")?; + serializer.skip_field("color_temp")?; + + serializer.serialize_field("color_mode", "hs")?; + + let color = SerializedColor { + h: Some(hue), + s: Some(saturation), + ..Default::default() + }; + + serializer.serialize_field("color", &color)? + } + Color::Xy { x, y } => { + serializer.skip_field("brightness")?; + serializer.skip_field("color_temp")?; + + serializer.serialize_field("color_mode", "xy")?; + + let color = SerializedColor { + x: Some(x), + y: Some(y), + ..Default::default() + }; + + serializer.serialize_field("color", &color)? + } + Color::Rgb { red, green, blue } => { + serializer.skip_field("brightness")?; + serializer.skip_field("color_temp")?; + + serializer.serialize_field("color_mode", "rgb")?; + + let color = SerializedColor { + r: Some(red), + g: Some(green), + b: Some(blue), + ..Default::default() + }; + + serializer.serialize_field("color", &color)? + } + Color::Rgbw { + red, + green, + blue, + white, + } => { + serializer.skip_field("brightness")?; + serializer.skip_field("color_temp")?; + + serializer.serialize_field("color_mode", "rgbw")?; + + let color = SerializedColor { + r: Some(red), + g: Some(green), + b: Some(blue), + w: Some(white), + ..Default::default() + }; + + serializer.serialize_field("color", &color)? + } + Color::Rgbww { + red, + green, + blue, + cool_white, + warm_white, + } => { + serializer.skip_field("brightness")?; + serializer.skip_field("color_temp")?; + + serializer.serialize_field("color_mode", "rgbww")?; + + let color = SerializedColor { + r: Some(red), + g: Some(green), + b: Some(blue), + c: Some(cool_white), + w: Some(warm_white), + ..Default::default() + }; + + serializer.serialize_field("color", &color)? + } + } + + serializer.end() + } +} + +/// A light entity +pub struct Light<'a, const C: usize, const E: usize> { + /// The color modes supported by the light. + pub supported_color_modes: [SupportedColorMode; C], + /// Any effects that can be used. + pub effects: [&'a str; E], +} + +impl Serialize for Light<'_, C, E> { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + let mut len = 2; + + if C > 0 { + len += 1; + } + + if E > 0 { + len += 2; + } + + let mut serializer = serializer.serialize_struct("Light", len)?; + + serializer.serialize_field("schema", "json")?; + + if C > 0 { + serializer.serialize_field("sup_clrm", &List::new(&self.supported_color_modes))?; + } else { + serializer.skip_field("sup_clrm")?; + } + + if E > 0 { + serializer.serialize_field("effect", &true)?; + serializer.serialize_field("fx_list", &List::new(&self.effects))?; + } else { + serializer.skip_field("effect")?; + serializer.skip_field("fx_list")?; + } + + serializer.end() + } +} + +impl Component for Light<'_, C, E> { + type State = LightState<'static>; + + fn platform() -> &'static str { + "light" + } + + async fn publish_state>( + &self, + topic: &Topic, + state: Self::State, + ) -> Result<(), Error> { + topic.with_json(state).publish().await + } +} diff --git a/Software/MainBoard/rust/src/mcutie_3_0_0/homeassistant/mod.rs b/Software/MainBoard/rust/src/mcutie_3_0_0/homeassistant/mod.rs new file mode 100644 index 0000000..8d98205 --- /dev/null +++ b/Software/MainBoard/rust/src/mcutie_3_0_0/homeassistant/mod.rs @@ -0,0 +1,295 @@ +//! Home Assistant auto-discovery and related messages. +//! +//! Normally you would declare your entities statically in your binary. It is +//! then trivial to send out discovery messages or state changes. +//! +//! ``` +//! # use mcutie::{Publishable, Topic}; +//! # use mcutie::homeassistant::{Entity, Device, Origin, AvailabilityState, AvailabilityTopics}; +//! # use mcutie::homeassistant::binary_sensor::{BinarySensor, BinarySensorClass, BinarySensorState}; +//! const DEVICE_AVAILABILITY_TOPIC: Topic<&'static str> = Topic::Device("status"); +//! const MOTION_STATE_TOPIC: Topic<&'static str> = Topic::Device("motion/status"); +//! +//! const DEVICE: Device<'static> = Device::new(); +//! const ORIGIN: Origin<'static> = Origin::new(); +//! +//! const MOTION_SENSOR: Entity<'static, 1, BinarySensor> = Entity { +//! device: DEVICE, +//! origin: ORIGIN, +//! object_id: "motion", +//! unique_id: Some("motion"), +//! name: "Motion", +//! availability: AvailabilityTopics::All([DEVICE_AVAILABILITY_TOPIC]), +//! state_topic: Some(MOTION_STATE_TOPIC), +//! command_topic: None, +//! component: BinarySensor { +//! device_class: Some(BinarySensorClass::Motion), +//! }, +//! }; +//! +//! async fn send_discovery_messages() { +//! MOTION_SENSOR.publish_discovery().await.unwrap(); +//! DEVICE_AVAILABILITY_TOPIC.with_bytes(AvailabilityState::Online).publish().await.unwrap(); +//! } +//! +//! async fn send_state(state: BinarySensorState) { +//! MOTION_SENSOR.publish_state(state).await.unwrap(); +//! } +//! ``` +use core::{future::Future, ops::Deref}; + +use mqttrs::QoS; +use serde::{ + ser::{Error as _, SerializeStruct}, + Serialize, Serializer, +}; + +use crate::{ + device_id, device_type, homeassistant::ser::DiscoverySerializer, io::publish, Error, + McutieTask, MqttMessage, Payload, Publishable, Topic, TopicString, DATA_CHANNEL, +}; + +pub mod binary_sensor; +pub mod button; +pub mod light; +pub mod sensor; +mod ser; + +const HA_STATUS_TOPIC: Topic<&'static str> = Topic::General("homeassistant/status"); +const STATE_ONLINE: &str = "online"; +const STATE_OFFLINE: &str = "offline"; + +/// A trait representing a specific type of entity in Home Assistant +pub trait Component: Serialize { + /// The state to publish. + type State; + + /// The platform identifier for this entity. Internal. + fn platform() -> &'static str; + + /// Publishes this entity's state to the MQTT broker. + fn publish_state>( + &self, + topic: &Topic, + state: Self::State, + ) -> impl Future>; +} + +impl<'t, T, L, const S: usize> McutieTask<'t, T, L, S> +where + T: Deref + 't, + L: Publishable + 't, +{ + pub(super) async fn ha_after_connected(&self) { + let _ = HA_STATUS_TOPIC.subscribe(false).await; + } + + pub(super) async fn ha_handle_update( + &self, + topic: &Topic, + payload: &Payload, + ) -> bool { + if topic == &HA_STATUS_TOPIC { + if payload.as_ref() == STATE_ONLINE.as_bytes() { + DATA_CHANNEL.send(MqttMessage::HomeAssistantOnline).await; + } + + true + } else { + false + } + } +} + +impl> Serialize for Topic { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + let mut topic = TopicString::new(); + self.to_string(&mut topic) + .map_err(|_| S::Error::custom("topic was too large to serialize"))?; + serializer.serialize_str(&topic) + } +} + +fn name_or_device(name: &Option<&str>, serializer: S) -> Result +where + S: Serializer, +{ + serializer.serialize_str(name.unwrap_or_else(|| device_type())) +} + +/// Represents the device in Home Assistant. +/// +/// Can just be the default in which case useful properties such as the ID are +/// automatically included. +#[derive(Clone, Copy, Default)] +pub struct Device<'a> { + /// A name to identify the device. If not provided the default device type is + /// used. + pub name: Option<&'a str>, + /// An optional configuration URL for the device. + pub configuration_url: Option<&'a str>, +} + +impl Device<'_> { + /// Creates a new default device. + pub const fn new() -> Self { + Self { + name: None, + configuration_url: None, + } + } +} + +impl Serialize for Device<'_> { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + let mut len = 2; + if self.configuration_url.is_some() { + len += 1; + } + + let mut serializer = serializer.serialize_struct("Device", len)?; + + serializer.serialize_field("name", self.name.unwrap_or_else(|| device_type()))?; + serializer.serialize_field("ids", device_id())?; + + if let Some(cu) = self.configuration_url { + serializer.serialize_field("cu", cu)?; + } else { + serializer.skip_field("cu")?; + } + + serializer.end() + } +} + +/// Represents the device's origin in Home Assistant. +/// +/// Can just be the default in which case useful properties are automatically +/// included. +#[derive(Clone, Copy, Default, Serialize)] +pub struct Origin<'a> { + /// A name to identify the device's origin. If not provided the default + /// device type is used. + #[serde(serialize_with = "name_or_device")] + pub name: Option<&'a str>, +} + +impl Origin<'_> { + /// Creates a new default origin. + pub const fn new() -> Self { + Self { name: None } + } +} + +/// A single entity for Home Assistant. +/// +/// Calling [`Entity::publish_discovery`] will publish the discovery message to +/// allow Home Assistant to detect this entity. Read the +/// [Home Assistant MQTT docs](https://www.home-assistant.io/integrations/mqtt/) +/// for information on what some of these properties mean. +pub struct Entity<'a, const A: usize, C: Component> { + /// The device this entity is a part of. + pub device: Device<'a>, + /// The origin of the device. + pub origin: Origin<'a>, + /// An object identifier to allow for entity ID customisation in Home Assistant. + pub object_id: &'a str, + /// An optional unique identifier for the entity. + pub unique_id: Option<&'a str>, + /// A friendly name for the entity. + pub name: &'a str, + /// Specifies the availability topics that Home Assistant will listen to to + /// determine this entity's availability. + pub availability: AvailabilityTopics<'a, A>, + /// The state topic that this entity's state is published to. + pub state_topic: Option>, + /// The command topic that this entity receives commands from. + pub command_topic: Option>, + /// The specific entity. + pub component: C, +} + +impl Entity<'_, A, C> { + /// Publishes the discovery message for this entity to the broker. + pub async fn publish_discovery(&self) -> Result<(), Error> { + let mut topic = TopicString::new(); + topic + .push_str(option_env!("HA_DISCOVERY_PREFIX").unwrap_or("homeassistant")) + .map_err(|_| Error::TooLarge)?; + topic.push('/').map_err(|_| Error::TooLarge)?; + topic.push_str(C::platform()).map_err(|_| Error::TooLarge)?; + topic.push('/').map_err(|_| Error::TooLarge)?; + topic + .push_str(self.object_id) + .map_err(|_| Error::TooLarge)?; + topic.push_str("/config").map_err(|_| Error::TooLarge)?; + + let mut payload = Payload::new(); + payload.serialize_json(self).map_err(|_| Error::TooLarge)?; + + publish(&topic, &payload, QoS::AtMostOnce, false).await + } + + /// Publishes this entity's state to the broker. + /// + /// # Errors + /// + /// - [`Error::Invalid`] if the entity doesn't have a state topic. + pub async fn publish_state(&self, state: C::State) -> Result<(), Error> { + if let Some(topic) = self.state_topic { + self.component.publish_state(&topic, state).await + } else { + Err(Error::Invalid) + } + } +} + +/// A payload representing a device or entity's availability. +#[allow(missing_docs)] +pub enum AvailabilityState { + Online, + Offline, +} + +impl AsRef<[u8]> for AvailabilityState { + fn as_ref(&self) -> &'static [u8] { + match self { + Self::Online => STATE_ONLINE.as_bytes(), + Self::Offline => STATE_OFFLINE.as_bytes(), + } + } +} + +/// The availiabity topics that home assistant will use to determine an entity's +/// availability. +pub enum AvailabilityTopics<'a, const A: usize> { + /// The entity is always available. + None, + /// The entity is available if all of the topics are publishes as online. + All([Topic<&'a str>; A]), + /// The entity is available if any of the topics are publishes as online. + Any([Topic<&'a str>; A]), + /// The entity is available based on the most recent of the topics to + /// publish state. + Latest([Topic<&'a str>; A]), +} + +impl Serialize for Entity<'_, A, C> { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + let outer = DiscoverySerializer { + discovery: self, + inner: serializer, + }; + + self.component.serialize(outer) + } +} diff --git a/Software/MainBoard/rust/src/mcutie_3_0_0/homeassistant/sensor.rs b/Software/MainBoard/rust/src/mcutie_3_0_0/homeassistant/sensor.rs new file mode 100644 index 0000000..bf48ecb --- /dev/null +++ b/Software/MainBoard/rust/src/mcutie_3_0_0/homeassistant/sensor.rs @@ -0,0 +1,103 @@ +//! Tools for publishing a [Home Assistant sensor](https://www.home-assistant.io/integrations/sensor.mqtt/). +use core::ops::Deref; + +use serde::Serialize; + +use crate::{homeassistant::Component, Error, Publishable, Topic}; + +/// The type of sensor. +#[derive(Serialize)] +#[serde(rename_all = "snake_case")] +#[allow(missing_docs)] +pub enum SensorClass { + ApparentPower, + Aqi, + AtmosphericPressure, + Battery, + CarbonDioxide, + CarbonMonoxide, + Current, + DataRate, + DataSize, + Date, + Distance, + Duration, + Energy, + EnergyStorage, + Enum, + Frequency, + Gas, + Humidity, + Illuminance, + Irradiance, + Moisture, + Monetary, + NitrogenDioxide, + NitrogenMonoxide, + NitrousOxide, + Ozone, + Ph, + Pm1, + Pm25, + Pm10, + PowerFactor, + Power, + Precipitation, + PrecipitationIntensity, + Pressure, + ReactivePower, + SignalStrength, + SoundPressure, + Speed, + SulphurDioxide, + Temperature, + Timestamp, + VolatileOrganicCompounds, + VolatileOrganicCompoundsParts, + Voltage, + Volume, + VolumeFlowRate, + VolumeStorage, + Water, + Weight, + WindSpeed, +} + +/// The type of measurement that this entity publishes. +#[derive(Serialize)] +#[serde(rename_all = "snake_case")] +pub enum SensorStateClass { + /// A measurement at a singe point in time. + Measurement, + /// A cumulative total that can increase or decrease over time. + Total, + /// A cumulative total that can only increase. + TotalIncreasing, +} + +/// A binary sensor that can publish a [`f32`] value. +#[derive(Serialize)] +pub struct Sensor<'u> { + /// The type of sensor. + pub device_class: Option, + /// The type of measurement that this sensor reports. + pub state_class: Option, + /// The unit of measurement for this sensor. + pub unit_of_measurement: Option<&'u str>, +} + +impl Component for Sensor<'_> { + type State = f32; + + fn platform() -> &'static str { + "sensor" + } + + async fn publish_state>( + &self, + topic: &Topic, + state: Self::State, + ) -> Result<(), Error> { + topic.with_display(state).publish().await + } +} diff --git a/Software/MainBoard/rust/src/mcutie_3_0_0/homeassistant/ser.rs b/Software/MainBoard/rust/src/mcutie_3_0_0/homeassistant/ser.rs new file mode 100644 index 0000000..80e99d9 --- /dev/null +++ b/Software/MainBoard/rust/src/mcutie_3_0_0/homeassistant/ser.rs @@ -0,0 +1,333 @@ +use core::ops::Deref; + +use serde::{ + ser::{SerializeSeq, SerializeStruct}, + Serialize, Serializer, +}; + +use crate::{ + homeassistant::{AvailabilityTopics, Component, Entity}, + Topic, +}; + +#[derive(Serialize)] +pub(super) struct AvailabilityTopicItem<'a> { + topic: Topic<&'a str>, +} + +struct AvailabilityTopicList<'a, T: Deref, const N: usize> { + list: &'a [Topic; N], +} + +impl<'a, const N: usize, T: Deref> AvailabilityTopicList<'a, T, N> { + pub(super) fn new(list: &'a [Topic; N]) -> Self { + Self { list } + } +} + +impl, const N: usize> Serialize for AvailabilityTopicList<'_, T, N> { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + let mut serializer = serializer.serialize_seq(Some(N))?; + + for topic in self.list { + serializer.serialize_element(&AvailabilityTopicItem { + topic: topic.as_ref(), + })?; + } + + serializer.end() + } +} + +pub(super) struct List<'a, T: Serialize, const N: usize> { + list: &'a [T; N], +} + +impl<'a, T: Serialize, const N: usize> List<'a, T, N> { + pub(super) fn new(list: &'a [T; N]) -> Self { + Self { list } + } +} + +impl Serialize for List<'_, T, N> { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + let mut serializer = serializer.serialize_seq(Some(N))?; + + for item in self.list { + serializer.serialize_element(item)?; + } + + serializer.end() + } +} + +pub(super) struct DiscoverySerializer<'a, const A: usize, C: Component, S: Serializer> { + pub(super) discovery: &'a Entity<'a, A, C>, + pub(super) inner: S, +} + +impl Serializer for DiscoverySerializer<'_, A, C, S> { + type Ok = S::Ok; + type Error = S::Error; + type SerializeSeq = S::SerializeSeq; + type SerializeTuple = S::SerializeTuple; + type SerializeTupleStruct = S::SerializeTupleStruct; + type SerializeTupleVariant = S::SerializeTupleVariant; + type SerializeMap = S::SerializeMap; + type SerializeStruct = S::SerializeStruct; + type SerializeStructVariant = S::SerializeStructVariant; + + fn serialize_struct( + self, + name: &'static str, + mut len: usize, + ) -> Result { + len += 5; + if self.discovery.state_topic.is_some() { + len += 1; + } + if self.discovery.command_topic.is_some() { + len += 1; + } + if self.discovery.unique_id.is_some() { + len += 1; + } + if !matches!(self.discovery.availability, AvailabilityTopics::None) { + len += 2; + } + + let mut serializer = self.inner.serialize_struct(name, len)?; + + serializer.serialize_field("dev", &self.discovery.device)?; + serializer.serialize_field("o", &self.discovery.origin)?; + serializer.serialize_field("p", C::platform())?; + serializer.serialize_field("obj_id", self.discovery.object_id)?; + + serializer.serialize_field("name", self.discovery.name)?; + + if let Some(t) = self.discovery.state_topic { + serializer.serialize_field("stat_t", &t)?; + } else { + serializer.skip_field("stat_t")?; + } + + if let Some(t) = self.discovery.command_topic { + serializer.serialize_field("cmd_t", &t)?; + } else { + serializer.skip_field("cmd_t")?; + } + + match &self.discovery.availability { + AvailabilityTopics::None => { + serializer.skip_field("avty")?; + serializer.skip_field("avty_mode")?; + } + AvailabilityTopics::All(topics) => { + serializer.serialize_field("avty_mode", "all")?; + serializer.serialize_field("avty", &AvailabilityTopicList::new(topics))?; + } + AvailabilityTopics::Any(topics) => { + serializer.serialize_field("avty_mode", "any")?; + serializer.serialize_field("avty", &AvailabilityTopicList::new(topics))?; + } + AvailabilityTopics::Latest(topics) => { + serializer.serialize_field("avty_mode", "latest")?; + serializer.serialize_field("avty", &AvailabilityTopicList::new(topics))?; + } + } + + if let Some(v) = self.discovery.unique_id { + serializer.serialize_field("uniq_id", v)?; + } else { + serializer.skip_field("uniq_id")?; + } + + Ok(serializer) + } + + fn serialize_bool(self, _: bool) -> Result { + unimplemented!() + } + + fn serialize_i8(self, _: i8) -> Result { + unimplemented!() + } + + fn serialize_i16(self, _: i16) -> Result { + unimplemented!() + } + + fn serialize_i32(self, _: i32) -> Result { + unimplemented!() + } + + fn serialize_i64(self, _: i64) -> Result { + unimplemented!() + } + + fn serialize_u8(self, _: u8) -> Result { + unimplemented!() + } + + fn serialize_u16(self, _: u16) -> Result { + unimplemented!() + } + + fn serialize_u32(self, _: u32) -> Result { + unimplemented!() + } + + fn serialize_u64(self, _: u64) -> Result { + unimplemented!() + } + + fn serialize_f32(self, _: f32) -> Result { + unimplemented!() + } + + fn serialize_f64(self, _: f64) -> Result { + unimplemented!() + } + + fn serialize_char(self, _: char) -> Result { + unimplemented!() + } + + fn serialize_str(self, _: &str) -> Result { + unimplemented!() + } + + fn serialize_bytes(self, _: &[u8]) -> Result { + unimplemented!() + } + + fn serialize_none(self) -> Result { + unimplemented!() + } + + fn serialize_some(self, _: &T) -> Result + where + T: ?Sized + Serialize, + { + unimplemented!() + } + + fn serialize_unit(self) -> Result { + unimplemented!() + } + + fn serialize_unit_struct(self, _: &'static str) -> Result { + unimplemented!() + } + + fn serialize_unit_variant( + self, + _: &'static str, + _: u32, + _: &'static str, + ) -> Result { + unimplemented!() + } + + fn serialize_newtype_struct(self, _: &'static str, _: &T) -> Result + where + T: ?Sized + Serialize, + { + unimplemented!() + } + + fn serialize_newtype_variant( + self, + _: &'static str, + _: u32, + _: &'static str, + _: &T, + ) -> Result + where + T: ?Sized + Serialize, + { + unimplemented!() + } + + fn serialize_seq(self, _: Option) -> Result { + unimplemented!() + } + + fn serialize_tuple(self, _: usize) -> Result { + unimplemented!() + } + + fn serialize_tuple_struct( + self, + _: &'static str, + _: usize, + ) -> Result { + unimplemented!() + } + + fn serialize_tuple_variant( + self, + _: &'static str, + _: u32, + _: &'static str, + _: usize, + ) -> Result { + unimplemented!() + } + + fn serialize_map(self, _: Option) -> Result { + unimplemented!() + } + + fn serialize_struct_variant( + self, + _: &'static str, + _: u32, + _: &'static str, + _: usize, + ) -> Result { + unimplemented!() + } + + fn serialize_i128(self, _: i128) -> Result { + unimplemented!() + } + + fn serialize_u128(self, _: u128) -> Result { + unimplemented!() + } + + fn collect_seq(self, _: I) -> Result + where + I: IntoIterator, + ::Item: Serialize, + { + unimplemented!() + } + + fn collect_map(self, _: I) -> Result + where + K: Serialize, + V: Serialize, + I: IntoIterator, + { + unimplemented!() + } + + fn collect_str(self, _: &T) -> Result + where + T: ?Sized + core::fmt::Display, + { + unimplemented!() + } + + fn is_human_readable(&self) -> bool { + unimplemented!() + } +} diff --git a/Software/MainBoard/rust/src/mcutie_3_0_0/io.rs b/Software/MainBoard/rust/src/mcutie_3_0_0/io.rs new file mode 100644 index 0000000..0ca402b --- /dev/null +++ b/Software/MainBoard/rust/src/mcutie_3_0_0/io.rs @@ -0,0 +1,483 @@ +use core::ops::Deref; + +pub(crate) use atomic16::assign_pid; +use embassy_futures::select::{select, select4, Either}; +use embassy_net::{ + dns::DnsQueryType, + tcp::{TcpReader, TcpSocket, TcpWriter}, + Stack, +}; +use embassy_sync::{ + blocking_mutex::raw::CriticalSectionRawMutex, + pubsub::{PubSubChannel, Subscriber, WaitResult}, +}; +use embassy_time::Timer; +use embedded_io_async::Write; +use mqttrs::{ + decode_slice, Connect, ConnectReturnCode, LastWill, Packet, Pid, Protocol, Publish, QoS, QosPid, +}; + +use crate::{ + device_id, fmt::Debug2Format, pipe::ConnectedPipe, ControlMessage, Error, MqttMessage, Payload, + Publishable, Topic, TopicString, CONFIRMATION_TIMEOUT, DATA_CHANNEL, DEFAULT_BACKOFF, + RESET_BACKOFF, +}; + +static SEND_QUEUE: ConnectedPipe = ConnectedPipe::new(); + +pub(crate) static CONTROL_CHANNEL: PubSubChannel = + PubSubChannel::new(); + +type ControlSubscriber = Subscriber<'static, CriticalSectionRawMutex, ControlMessage, 2, 5, 0>; + +pub(crate) async fn subscribe() -> ControlSubscriber { + loop { + if let Ok(sub) = CONTROL_CHANNEL.subscriber() { + return sub; + } + + Timer::after_millis(50).await; + } +} + +#[cfg(target_has_atomic = "16")] +mod atomic16 { + use core::sync::atomic::{AtomicU16, Ordering}; + + use mqttrs::Pid; + + static PID: AtomicU16 = AtomicU16::new(0); + + pub(crate) async fn assign_pid() -> Pid { + Pid::new() + PID.fetch_add(1, Ordering::SeqCst) + } +} + +#[cfg(not(target_has_atomic = "16"))] +mod atomic16 { + use embassy_sync::{blocking_mutex::raw::CriticalSectionRawMutex, mutex::Mutex}; + use mqttrs::Pid; + + static PID_MUTEX: Mutex = Mutex::new(0); + + pub(crate) async fn assign_pid() -> Pid { + let mut locked = PID_MUTEX.lock().await; + *locked += 1; + + Pid::new() + *locked + } +} + +pub(crate) async fn send_packet(packet: Packet<'_>) -> Result<(), Error> { + let mut buffer = Payload::new(); + + match buffer.encode_packet(&packet) { + Ok(()) => { + debug!( + "Sending packet to broker: {:?}", + Debug2Format(&packet.get_type()) + ); + SEND_QUEUE.push(buffer).await; + Ok(()) + } + Err(_) => { + error!("Failed to send packet"); + Err(Error::PacketError) + } + } +} + +pub(crate) async fn wait_for_publish( + mut subscriber: ControlSubscriber, + expected_pid: Pid, +) -> Result<(), Error> { + match select( + async { + loop { + match subscriber.next_message().await { + WaitResult::Lagged(_) => { + // Maybe we missed the message? + } + WaitResult::Message(ControlMessage::Published(published_pid)) => { + if published_pid == expected_pid { + return Ok(()); + } + } + _ => {} + } + } + }, + Timer::after_millis(CONFIRMATION_TIMEOUT), + ) + .await + { + Either::First(r) => r, + Either::Second(_) => Err(Error::TimedOut), + } +} + +pub(crate) async fn publish( + topic_name: &str, + payload: &[u8], + qos: QoS, + retain: bool, +) -> Result<(), Error> { + let subscriber = subscribe().await; + + let (qospid, pid) = match qos { + QoS::AtMostOnce => (QosPid::AtMostOnce, None), + QoS::AtLeastOnce => { + let pid = assign_pid().await; + (QosPid::AtLeastOnce(pid), Some(pid)) + } + QoS::ExactlyOnce => { + let pid = assign_pid().await; + (QosPid::ExactlyOnce(pid), Some(pid)) + } + }; + + let packet = Packet::Publish(Publish { + dup: false, + qospid, + retain, + topic_name, + payload, + }); + + send_packet(packet).await?; + + if let Some(expected_pid) = pid { + wait_for_publish(subscriber, expected_pid).await + } else { + Ok(()) + } +} + +fn packet_size(buffer: &[u8]) -> Option { + let mut pos = 1; + let mut multiplier = 1; + let mut value = 0; + + while pos < buffer.len() { + value += (buffer[pos] & 127) as usize * multiplier; + multiplier *= 128; + + if (buffer[pos] & 128) == 0 { + return Some(value + pos + 1); + } + + pos += 1; + if pos == 5 { + return Some(0); + } + } + + None +} + +/// The MQTT task that must be run in order for the stack to operate. +pub struct McutieTask<'t, T, L, const S: usize> +where + T: Deref + 't, + L: Publishable + 't, +{ + pub(crate) network: Stack<'t>, + pub(crate) broker: &'t str, + pub(crate) last_will: Option, + pub(crate) username: Option<&'t str>, + pub(crate) password: Option<&'t str>, + pub(crate) subscriptions: [Topic; S], + pub(crate) keep_alive: u16 +} + +impl<'t, T, L, const S: usize> McutieTask<'t, T, L, S> +where + T: Deref + 't, + L: Publishable + 't, +{ + #[cfg(not(feature = "homeassistant"))] + async fn ha_handle_update(&self, _topic: &Topic, _payload: &Payload) -> bool { + false + } + + async fn recv_loop(&self, mut reader: TcpReader<'_>) -> Result<(), Error> { + let mut buffer = [0_u8; 4096]; + let mut cursor: usize = 0; + + let controller = CONTROL_CHANNEL.immediate_publisher(); + + loop { + match reader.read(&mut buffer[cursor..]).await { + Ok(0) => { + error!("Receive socket closed"); + return Ok(()); + } + Ok(len) => { + cursor += len; + } + Err(_) => { + error!("I/O failure reading packet"); + return Err(Error::IOError); + } + } + + let mut start_pos = 0; + loop { + let packet_length = match packet_size(&buffer[start_pos..cursor]) { + Some(0) => { + error!("Invalid MQTT packet"); + return Err(Error::PacketError); + } + Some(len) => len, + None => { + // None is returned when there is not yet enough data to decode a packet. + if start_pos != 0 { + // Adjust the buffer to reclaim any unused data + buffer.copy_within(start_pos..cursor, 0); + cursor -= start_pos; + } + break; + } + }; + + let packet = match decode_slice(&buffer[start_pos..(start_pos + packet_length)]) { + Ok(Some(p)) => p, + Ok(None) => { + error!("Packet length calculation failed."); + return Err(Error::PacketError); + } + Err(_) => { + error!("Invalid MQTT packet"); + return Err(Error::PacketError); + } + }; + + debug!( + "Received packet from broker: {:?}", + Debug2Format(&packet.get_type()) + ); + + match packet { + Packet::Connack(connack) => match connack.code { + ConnectReturnCode::Accepted => { + #[cfg(feature = "homeassistant")] + self.ha_after_connected().await; + + for topic in &self.subscriptions { + let _ = topic.subscribe(false).await; + } + + DATA_CHANNEL.send(MqttMessage::Connected).await; + } + _ => { + error!("Connection request to broker was not accepted"); + return Err(Error::IOError); + } + }, + Packet::Pingresp => {} + + Packet::Publish(publish) => { + match ( + Topic::from_str(publish.topic_name), + Payload::from(publish.payload), + ) { + (Ok(topic), Ok(payload)) => { + if !self.ha_handle_update(&topic, &payload).await { + DATA_CHANNEL + .send(MqttMessage::Publish(topic, payload)) + .await; + } + } + _ => { + error!("Unable to process publish data as it was too large"); + } + } + + match publish.qospid { + mqttrs::QosPid::AtMostOnce => {} + mqttrs::QosPid::AtLeastOnce(pid) => { + send_packet(Packet::Puback(pid)).await?; + } + mqttrs::QosPid::ExactlyOnce(pid) => { + send_packet(Packet::Pubrec(pid)).await?; + } + } + } + Packet::Puback(pid) => { + controller.publish_immediate(ControlMessage::Published(pid)); + } + Packet::Pubrec(pid) => { + controller.publish_immediate(ControlMessage::Published(pid)); + send_packet(Packet::Pubrel(pid)).await?; + } + Packet::Pubrel(pid) => send_packet(Packet::Pubrel(pid)).await?, + Packet::Pubcomp(_) => {} + + Packet::Suback(suback) => { + if let Some(return_code) = suback.return_codes.first() { + controller.publish_immediate(ControlMessage::Subscribed( + suback.pid, + *return_code, + )); + } else { + warn!("Unexpected suback with no return codes"); + } + } + Packet::Unsuback(pid) => { + controller.publish_immediate(ControlMessage::Unsubscribed(pid)); + } + + Packet::Connect(_) + | Packet::Subscribe(_) + | Packet::Pingreq + | Packet::Unsubscribe(_) + | Packet::Disconnect => { + debug!( + "Unexpected packet from broker: {:?}", + Debug2Format(&packet.get_type()) + ); + } + } + + start_pos += packet_length; + if start_pos == cursor { + cursor = 0; + break; + } + } + } + } + + async fn write_loop(&self, mut writer: TcpWriter<'_>) { + let mut buffer = Payload::new(); + + let mut last_will_topic = TopicString::new(); + let mut last_will_payload = Payload::new(); + + let last_will = self.last_will.as_ref().and_then(|p| { + if p.write_topic(&mut last_will_topic).is_ok() + && p.write_payload(&mut last_will_payload).is_ok() + { + Some(LastWill { + topic: &last_will_topic, + message: &last_will_payload, + qos: p.qos(), + retain: p.retain(), + }) + } else { + None + } + }); + + // Send our connection request. + if buffer + .encode_packet(&Packet::Connect(Connect { + protocol: Protocol::MQTT311, + keep_alive: self.keep_alive, + client_id: device_id(), + clean_session: true, + last_will, + username: self.username, + password: self.password.map(|s| s.as_bytes()), + })) + .is_err() + { + error!("Failed to encode connection packet"); + return; + } + + if let Err(e) = writer.write_all(&buffer).await { + error!("Failed to send connection packet: {:?}", e); + return; + } + + let reader = SEND_QUEUE.reader(); + + loop { + let buffer = reader.receive().await; + + trace!("Writer sending packet"); + if let Err(e) = writer.write_all(&buffer).await { + error!("Failed to send data: {:?}", e); + return; + } + } + } + + /// Runs the MQTT stack. The future returned from this must be awaited for everything to work. + pub async fn run(self) { + let mut timeout: Option = None; + + let mut rx_buffer = [0; 4096]; + let mut tx_buffer = [0; 4096]; + + loop { + if let Some(millis) = timeout.replace(DEFAULT_BACKOFF) { + Timer::after_millis(millis).await; + } + + if !self.network.is_config_up() { + debug!("Waiting for network to configure."); + self.network.wait_config_up().await; + debug!("Network configured."); + } + + let ip_addrs = match self.network.dns_query(self.broker, DnsQueryType::A).await { + Ok(v) => v, + Err(e) => { + error!("Failed to lookup '{}' for broker: {:?}", self.broker, e); + continue; + } + }; + + let ip = match ip_addrs.first() { + Some(i) => *i, + None => { + error!("No IP address found for broker '{}'", self.broker); + continue; + } + }; + + debug!("Connecting to {}:1883", ip); + + let mut socket = TcpSocket::new(self.network, &mut rx_buffer, &mut tx_buffer); + if let Err(e) = socket.connect((ip, 1883)).await { + error!("Failed to connect to {}:1883: {:?}", ip, e); + continue; + } + + info!("Connected to {}", self.broker); + timeout = Some(RESET_BACKOFF); + + let (reader, writer) = socket.split(); + + let recv_loop = self.recv_loop(reader); + let send_loop = self.write_loop(writer); + + let ping_loop = async { + loop { + Timer::after_secs(45).await; + + let _ = send_packet(Packet::Pingreq).await; + } + }; + + let link_down = async { + self.network.wait_link_down().await; + warn!("Network link lost"); + }; + + let ip_down = async { + self.network.wait_config_down().await; + warn!("Network config lost"); + }; + + select4(send_loop, ping_loop, recv_loop, select(link_down, ip_down)).await; + + socket.close(); + + warn!("Lost connection with broker"); + DATA_CHANNEL.send(MqttMessage::Disconnected).await; + } + } +} diff --git a/Software/MainBoard/rust/src/mcutie_3_0_0/lib.rs b/Software/MainBoard/rust/src/mcutie_3_0_0/lib.rs new file mode 100644 index 0000000..126c9af --- /dev/null +++ b/Software/MainBoard/rust/src/mcutie_3_0_0/lib.rs @@ -0,0 +1,227 @@ +#![no_std] +#![deny(unreachable_pub)] +#![warn(missing_docs)] +#![cfg_attr(docsrs, feature(doc_auto_cfg))] +//! MQTT client support crate vendored into this repository. + +use core::{ops::Deref, str}; + +pub use buffer::Buffer; +use embassy_net::{HardwareAddress, Stack}; +use embassy_sync::{blocking_mutex::raw::CriticalSectionRawMutex, channel::Channel}; +use heapless::String; +pub use io::McutieTask; +pub use mqttrs::QoS; +use mqttrs::{Pid, SubscribeReturnCodes}; +use once_cell::sync::OnceCell; +pub use publish::*; +pub use topic::Topic; + +// This must come first so the macros are visible +pub(crate) mod fmt; + +mod buffer; +#[cfg(feature = "homeassistant")] +pub mod homeassistant; +mod io; +mod pipe; +mod publish; +mod topic; + +// This really needs to match that used by mqttrs. +const TOPIC_LENGTH: usize = 256; +const PAYLOAD_LENGTH: usize = 2048; + +/// A fixed length stack allocated string. The length is fixed by the mqttrs crate. +pub type TopicString = String; +/// A fixed length buffer of 2048 bytes. +pub type Payload = Buffer; + +// By default in the event of an error connecting to the broker we will wait for 5s. +const DEFAULT_BACKOFF: u64 = 5000; +// If the connection dropped then re-connect more quickly. +const RESET_BACKOFF: u64 = 200; +// How long to wait for the broker to confirm actions. +const CONFIRMATION_TIMEOUT: u64 = 2000; + +static DATA_CHANNEL: Channel = Channel::new(); + +static DEVICE_TYPE: OnceCell> = OnceCell::new(); +static DEVICE_ID: OnceCell> = OnceCell::new(); + +fn device_id() -> &'static str { + DEVICE_ID.get().unwrap() +} + +fn device_type() -> &'static str { + DEVICE_TYPE.get().unwrap() +} + +/// Various errors +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub enum Error { + /// An IO error occured. + IOError, + /// The operation timed out. + TimedOut, + /// An attempt was made to encode something too large. + TooLarge, + /// A packet or payload could not be decoded or encoded. + PacketError, + /// An invalid or unsupported operation was attempted. + Invalid, + /// A value was rejected. + Rejected, +} + +#[allow(clippy::large_enum_variant)] +/// A message from the MQTT broker. +pub enum MqttMessage { + /// The broker has been connected to successfully. Generally in response to this message a + /// device should subscribe to topics of interest and send out any device state. + Connected, + /// New data received from the broker. + Publish(Topic, Payload), + /// The connection to the broker has been dropped. + Disconnected, + /// Home Assistant has come online and you should send any discovery messages. + #[cfg(feature = "homeassistant")] + HomeAssistantOnline, +} + +#[derive(Clone)] +enum ControlMessage { + Published(Pid), + Subscribed(Pid, SubscribeReturnCodes), + Unsubscribed(Pid), +} + +/// Receives messages from the broker. +pub struct McutieReceiver; + +impl McutieReceiver { + /// Waits for the next message from the broker. + pub async fn receive(&self) -> MqttMessage { + DATA_CHANNEL.receive().await + } +} + +/// A builder to configure the MQTT stack. +pub struct McutieBuilder<'t, T, L, const S: usize> +where + T: Deref + 't, + L: Publishable + 't, +{ + network: Stack<'t>, + device_type: &'t str, + device_id: Option<&'t str>, + broker: &'t str, + last_will: Option, + username: Option<&'t str>, + password: Option<&'t str>, + subscriptions: [Topic; S], +} + +impl<'t, T: Deref + 't, L: Publishable + 't> McutieBuilder<'t, T, L, 0> { + /// Creates a new builder with the initial required configuration. + /// + /// `device_type` is expected to be the same for all devices of the same type. + /// `broker` may be an IP address or a DNS name for the broker to connect to. + pub fn new(network: Stack<'t>, device_type: &'t str, broker: &'t str) -> Self { + Self { + network, + device_type, + broker, + device_id: None, + last_will: None, + username: None, + password: None, + subscriptions: [], + } + } +} + +impl<'t, T: Deref + 't, L: Publishable + 't, const S: usize> + McutieBuilder<'t, T, L, S> +{ + /// Add some default topics to subscribe to. + pub fn with_subscriptions( + self, + subscriptions: [Topic; N], + ) -> McutieBuilder<'t, T, L, N> { + McutieBuilder { + network: self.network, + device_type: self.device_type, + broker: self.broker, + device_id: self.device_id, + last_will: self.last_will, + username: self.username, + password: self.password, + subscriptions, + } + } +} + +impl<'t, T: Deref + 't, L: Publishable + 't, const S: usize> + McutieBuilder<'t, T, L, S> +{ + /// Adds authentication for the broker. + pub fn with_authentication(self, username: &'t str, password: &'t str) -> Self { + Self { + username: Some(username), + password: Some(password), + ..self + } + } + + /// Sets a last will message to be published in the event of disconnection. + pub fn with_last_will(self, last_will: L) -> Self { + Self { + last_will: Some(last_will), + ..self + } + } + + /// Sets a custom unique device identifier. If none is set then the network + /// MAC address is used. + pub fn with_device_id(self, device_id: &'t str) -> Self { + Self { + device_id: Some(device_id), + ..self + } + } + + /// Initialises the MQTT stack returning a receiver for listening to + /// messages from the broker and a future that must be run in order for the + /// stack to operate. + pub fn build(self, keep_alive: u16) -> (McutieReceiver, McutieTask<'t, T, L, S>) { + let mut dtype = String::<32>::new(); + dtype.push_str(self.device_type).unwrap(); + DEVICE_TYPE.set(dtype).unwrap(); + + let mut did = String::<32>::new(); + if let Some(device_id) = self.device_id { + did.push_str(device_id).unwrap(); + } else if let HardwareAddress::Ethernet(address) = self.network.hardware_address() { + let mut buffer = [0_u8; 12]; + hex::encode_to_slice(address.as_bytes(), &mut buffer).unwrap(); + did.push_str(str::from_utf8(&buffer).unwrap()).unwrap(); + } + + DEVICE_ID.set(did).unwrap(); + + ( + McutieReceiver {}, + McutieTask { + network: self.network, + broker: self.broker, + last_will: self.last_will, + username: self.username, + password: self.password, + subscriptions: self.subscriptions, + keep_alive + }, + ) + } +} diff --git a/Software/MainBoard/rust/src/mcutie_3_0_0/pipe.rs b/Software/MainBoard/rust/src/mcutie_3_0_0/pipe.rs new file mode 100644 index 0000000..9df156f --- /dev/null +++ b/Software/MainBoard/rust/src/mcutie_3_0_0/pipe.rs @@ -0,0 +1,267 @@ +use core::{ + cell::RefCell, + future::Future, + pin::Pin, + task::{Context, Poll, Waker}, +}; + +use embassy_sync::blocking_mutex::{raw::RawMutex, Mutex}; +use pin_project::pin_project; + +struct PipeData { + connect_count: usize, + receiver_waker: Option, + sender_waker: Option, + pending: Option, +} + +fn swap_wakers(waker: &mut Option, new_waker: &Waker) { + if let Some(old_waker) = waker.take() { + if old_waker.will_wake(new_waker) { + *waker = Some(old_waker) + } else { + if !new_waker.will_wake(&old_waker) { + old_waker.wake(); + } + + *waker = Some(new_waker.clone()); + } + } else { + *waker = Some(new_waker.clone()) + } +} + +pub(crate) struct ReceiveFuture<'a, M: RawMutex, T, const N: usize> { + pipe: &'a ConnectedPipe, +} + +impl Future for ReceiveFuture<'_, M, T, N> { + type Output = T; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + self.pipe.inner.lock(|cell| { + let mut inner = cell.borrow_mut(); + + if let Some(waker) = inner.sender_waker.take() { + waker.wake(); + } + + if let Some(item) = inner.pending.take() { + if let Some(old_waker) = inner.receiver_waker.take() { + old_waker.wake(); + } + + Poll::Ready(item) + } else { + swap_wakers(&mut inner.receiver_waker, cx.waker()); + Poll::Pending + } + }) + } +} + +pub(crate) struct PipeReader<'a, M: RawMutex, T, const N: usize> { + pipe: &'a ConnectedPipe, +} + +impl PipeReader<'_, M, T, N> { + #[must_use] + pub(crate) fn receive(&self) -> ReceiveFuture<'_, M, T, N> { + ReceiveFuture { pipe: self.pipe } + } +} + +impl Drop for PipeReader<'_, M, T, N> { + fn drop(&mut self) { + self.pipe.inner.lock(|cell| { + let mut inner = cell.borrow_mut(); + inner.connect_count -= 1; + + if inner.connect_count == 0 { + inner.pending = None; + } + + if let Some(waker) = inner.sender_waker.take() { + waker.wake(); + } + }) + } +} + +#[pin_project] +pub(crate) struct PushFuture<'a, M: RawMutex, T, const N: usize> { + data: Option, + pipe: &'a ConnectedPipe, +} + +impl Future for PushFuture<'_, M, T, N> { + type Output = (); + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + self.pipe.inner.lock(|cell| { + let project = self.project(); + let mut inner = cell.borrow_mut(); + + if let Some(receiver) = inner.receiver_waker.take() { + receiver.wake(); + } + + if project.data.is_none() || inner.connect_count == 0 { + trace!("Dropping packet"); + Poll::Ready(()) + } else if inner.pending.is_some() { + swap_wakers(&mut inner.sender_waker, cx.waker()); + Poll::Pending + } else { + inner.pending = project.data.take(); + + Poll::Ready(()) + } + }) + } +} + +/// A pipe that knows whether a receiver is connected. If so pushing to the +/// queue waits until there is space in the queue, otherwise data is simply +/// dropped. +pub(crate) struct ConnectedPipe { + inner: Mutex>>, +} + +impl ConnectedPipe { + pub(crate) const fn new() -> Self { + Self { + inner: Mutex::new(RefCell::new(PipeData { + connect_count: 0, + receiver_waker: None, + sender_waker: None, + pending: None, + })), + } + } + + /// A future that waits for a new item to be available. + pub(crate) fn reader(&self) -> PipeReader<'_, M, T, N> { + self.inner.lock(|cell| { + let mut inner = cell.borrow_mut(); + inner.connect_count += 1; + + PipeReader { pipe: self } + }) + } + + /// Pushes an item to the reader, waiting for a slot to become available if + /// connected. + #[must_use] + pub(crate) fn push(&self, data: T) -> PushFuture<'_, M, T, N> { + PushFuture { + data: Some(data), + pipe: self, + } + } +} + +#[cfg(test)] +mod tests { + use core::time::Duration; + + use embassy_sync::blocking_mutex::raw::CriticalSectionRawMutex; + use futures_executor::{LocalPool, ThreadPool}; + use futures_timer::Delay; + use futures_util::{future::select, pin_mut, task::SpawnExt, FutureExt}; + + use super::ConnectedPipe; + + async fn wait_milis(milis: u64) { + Delay::new(Duration::from_millis(milis)).await; + } + + // #[futures_test::test] + #[test] + fn test_send_receive() { + let mut executor = LocalPool::new(); + let spawner = executor.spawner(); + + static PIPE: ConnectedPipe = ConnectedPipe::new(); + + // Task that sends + spawner + .spawn(async { + wait_milis(10).await; + + PIPE.push(23).await; + PIPE.push(56).await; + PIPE.push(67).await; + }) + .unwrap(); + + // Task that receives + spawner + .spawn(async { + let reader = PIPE.reader(); + let value = reader.receive().await; + assert_eq!(value, 23); + let value = reader.receive().await; + assert_eq!(value, 56); + let value = reader.receive().await; + assert_eq!(value, 67); + }) + .unwrap(); + + executor.run(); + } + + #[futures_test::test] + async fn test_send_drop() { + static PIPE: ConnectedPipe = ConnectedPipe::new(); + + PIPE.push(23).await; + PIPE.push(56).await; + PIPE.push(67).await; + + // Create a reader after sending + let reader = PIPE.reader(); + let receive = reader.receive().fuse(); + pin_mut!(receive); + + let timeout = wait_milis(50).fuse(); + pin_mut!(timeout); + + let either = select(receive, timeout).await; + + match either { + futures_util::future::Either::Left(_) => { + panic!("There should be nothing to receive!"); + } + futures_util::future::Either::Right(_) => {} + } + } + + #[futures_test::test] + async fn test_bulk_send_publish() { + static PIPE: ConnectedPipe = ConnectedPipe::new(); + + let executor = ThreadPool::new().unwrap(); + + executor + .spawn(async { + for i in 0..1000 { + PIPE.push(i).await; + } + }) + .unwrap(); + + executor + .spawn(async { + for i in 1000..2000 { + PIPE.push(i).await; + } + }) + .unwrap(); + + let reader = PIPE.reader(); + for _ in 0..800 { + reader.receive().await; + } + } +} diff --git a/Software/MainBoard/rust/src/mcutie_3_0_0/publish.rs b/Software/MainBoard/rust/src/mcutie_3_0_0/publish.rs new file mode 100644 index 0000000..ef0ea14 --- /dev/null +++ b/Software/MainBoard/rust/src/mcutie_3_0_0/publish.rs @@ -0,0 +1,173 @@ +use core::{fmt::Display, future::Future, ops::Deref}; + +use embedded_io::Write; +use mqttrs::QoS; + +use crate::{io::publish, Error, Payload, Topic, TopicString}; + +/// A message that can be published to an MQTT broker. +pub trait Publishable { + /// Write this message's topic into the supplied buffer. + fn write_topic(&self, buffer: &mut TopicString) -> Result<(), Error>; + + /// Write this message's payload into the supplied buffer. + fn write_payload(&self, buffer: &mut Payload) -> Result<(), Error>; + + /// Get this message's QoS level. + fn qos(&self) -> QoS { + QoS::AtMostOnce + } + + /// Whether the broker should retain this message. + fn retain(&self) -> bool { + false + } + + /// Publishes this message to the broker. If the stack has not yet been + /// initialized this is likely to panic. + fn publish(&self) -> impl Future> { + async { + let mut topic = TopicString::new(); + self.write_topic(&mut topic)?; + + let mut payload = Payload::new(); + self.write_payload(&mut payload)?; + + publish(&topic, &payload, self.qos(), self.retain()).await + } + } +} + +/// A [`Publishable`] with a raw byte payload. +pub struct PublishBytes<'a, T, B: AsRef<[u8]>> { + pub(crate) topic: &'a Topic, + pub(crate) data: B, + pub(crate) qos: QoS, + pub(crate) retain: bool, +} + +impl> PublishBytes<'_, T, B> { + /// Sets the QoS level for this message. + pub fn qos(mut self, qos: QoS) -> Self { + self.qos = qos; + self + } + + /// Sets whether the broker should retain this message. + pub fn retain(mut self, retain: bool) -> Self { + self.retain = retain; + self + } +} + +impl<'a, T: Deref + 'a, B: AsRef<[u8]>> Publishable for PublishBytes<'a, T, B> { + fn write_topic(&self, buffer: &mut TopicString) -> Result<(), Error> { + self.topic.to_string(buffer) + } + + fn write_payload(&self, buffer: &mut Payload) -> Result<(), Error> { + buffer + .write_all(self.data.as_ref()) + .map_err(|_| Error::TooLarge) + } + + fn qos(&self) -> QoS { + self.qos + } + + fn retain(&self) -> bool { + self.retain + } + + async fn publish(&self) -> Result<(), Error> { + let mut topic = TopicString::new(); + self.write_topic(&mut topic)?; + + publish(&topic, self.data.as_ref(), self.qos(), self.retain()).await + } +} + +/// A [`Publishable`] with a payload that implements [`Display`]. +pub struct PublishDisplay<'a, T, D: Display> { + pub(crate) topic: &'a Topic, + pub(crate) data: D, + pub(crate) qos: QoS, + pub(crate) retain: bool, +} + +impl PublishDisplay<'_, T, D> { + /// Sets the QoS level for this message. + pub fn qos(mut self, qos: QoS) -> Self { + self.qos = qos; + self + } + + /// Sets whether the broker should retain this message. + pub fn retain(mut self, retain: bool) -> Self { + self.retain = retain; + self + } +} + +impl<'a, T: Deref + 'a, D: Display> Publishable for PublishDisplay<'a, T, D> { + fn write_topic(&self, buffer: &mut TopicString) -> Result<(), Error> { + self.topic.to_string(buffer) + } + + fn write_payload(&self, buffer: &mut Payload) -> Result<(), Error> { + write!(buffer, "{}", self.data).map_err(|_| Error::TooLarge) + } + + fn qos(&self) -> QoS { + self.qos + } + + fn retain(&self) -> bool { + self.retain + } +} + +#[cfg(feature = "serde")] +/// A [`Publishable`] with that serializes a JSON payload. +pub struct PublishJson<'a, T, D: serde::Serialize> { + pub(crate) topic: &'a Topic, + pub(crate) data: D, + pub(crate) qos: QoS, + pub(crate) retain: bool, +} + +#[cfg(feature = "serde")] +impl PublishJson<'_, T, D> { + /// Sets the QoS level for this message. + pub fn qos(mut self, qos: QoS) -> Self { + self.qos = qos; + self + } + + /// Sets whether the broker should retain this message. + pub fn retain(mut self, retain: bool) -> Self { + self.retain = retain; + self + } +} + +#[cfg(feature = "serde")] +impl<'a, T: Deref + 'a, D: serde::Serialize> Publishable for PublishJson<'a, T, D> { + fn write_topic(&self, buffer: &mut TopicString) -> Result<(), Error> { + self.topic.to_string(buffer) + } + + fn write_payload(&self, buffer: &mut Payload) -> Result<(), Error> { + buffer + .serialize_json(&self.data) + .map_err(|_| Error::TooLarge) + } + + fn qos(&self) -> QoS { + self.qos + } + + fn retain(&self) -> bool { + self.retain + } +} diff --git a/Software/MainBoard/rust/src/mcutie_3_0_0/topic.rs b/Software/MainBoard/rust/src/mcutie_3_0_0/topic.rs new file mode 100644 index 0000000..259fd5b --- /dev/null +++ b/Software/MainBoard/rust/src/mcutie_3_0_0/topic.rs @@ -0,0 +1,284 @@ +use core::{fmt::Display, ops::Deref}; + +use embassy_futures::select::{select, Either}; +use embassy_sync::pubsub::WaitResult; +use embassy_time::Timer; +use heapless::{String, Vec}; +use mqttrs::{Packet, QoS, Subscribe, SubscribeReturnCodes, SubscribeTopic, Unsubscribe}; + +#[cfg(feature = "serde")] +use crate::publish::PublishJson; +use crate::{ + device_id, device_type, + io::{assign_pid, send_packet, subscribe}, + publish::{PublishBytes, PublishDisplay}, + ControlMessage, Error, TopicString, CONFIRMATION_TIMEOUT, +}; + +/// An MQTT topic that is optionally prefixed with the device type and unique ID. +/// Normally you will define all your application's topics as consts with static +/// lifetimes. +/// +/// A [`Topic`] is the main entry to publishing messages to the broker. +/// +/// ``` +/// # use mcutie::{Publishable, Topic}; +/// const DEVICE_AVAILABILITY: Topic<&'static str> = Topic::Device("state"); +/// +/// async fn send_status(status: &'static str) { +/// let _ = DEVICE_AVAILABILITY.with_bytes(status.as_bytes()).publish().await; +/// } +/// ``` +#[derive(Clone, Copy)] +pub enum Topic { + /// A topic that is prefixed with the device type. + DeviceType(T), + /// A topic that is prefixed with the device type and unique ID. + Device(T), + /// Any topic. + General(T), +} + +impl PartialEq> for Topic +where + B: PartialEq, +{ + fn eq(&self, other: &Topic) -> bool { + match (self, other) { + (Topic::DeviceType(l0), Topic::DeviceType(r0)) => l0 == r0, + (Topic::Device(l0), Topic::Device(r0)) => l0 == r0, + (Topic::General(l0), Topic::General(r0)) => l0 == r0, + _ => false, + } + } +} + +impl Topic { + /// Creates a publishable message with something that can return a reference + /// to the payload in bytes. + /// + /// Defaults to non-retained with QoS of 0 (AtMostOnce). + pub fn with_bytes>(&self, data: B) -> PublishBytes<'_, T, B> { + PublishBytes { + topic: self, + data, + qos: QoS::AtMostOnce, + retain: false, + } + } + + /// Creates a publishable message with something that implements [`Display`]. + /// + /// Defaults to non-retained with QoS of 0 (AtMostOnce). + pub fn with_display(&self, data: D) -> PublishDisplay<'_, T, D> { + PublishDisplay { + topic: self, + data, + qos: QoS::AtMostOnce, + retain: false, + } + } + + #[cfg(feature = "serde")] + /// Creates a publishable message with something that can be serialized to + /// JSON. + /// + /// Defaults to non-retained with QoS of 0 (AtMostOnce). + pub fn with_json(&self, data: D) -> PublishJson<'_, T, D> { + PublishJson { + topic: self, + data, + qos: QoS::AtMostOnce, + retain: false, + } + } +} + +impl Topic { + pub(crate) fn from_str(mut st: &str) -> Result { + let mut strip_prefix = |pr: &str| -> bool { + if st.starts_with(pr) && st.len() > pr.len() && &st[pr.len()..pr.len() + 1] == "/" { + st = &st[pr.len() + 1..]; + true + } else { + false + } + }; + + if strip_prefix(device_type()) { + if strip_prefix(device_id()) { + let mut topic = TopicString::new(); + topic.push_str(st).map_err(|_| Error::TooLarge)?; + Ok(Topic::Device(topic)) + } else { + let mut topic = TopicString::new(); + topic.push_str(st).map_err(|_| Error::TooLarge)?; + Ok(Topic::DeviceType(topic)) + } + } else { + let mut topic = TopicString::new(); + topic.push_str(st).map_err(|_| Error::TooLarge)?; + Ok(Topic::General(topic)) + } + } +} + +impl> Topic { + pub(crate) fn to_string(&self, result: &mut String) -> Result<(), Error> { + match self { + Topic::Device(st) => { + result + .push_str(device_type()) + .map_err(|_| Error::TooLarge)?; + result.push_str("/").map_err(|_| Error::TooLarge)?; + result.push_str(device_id()).map_err(|_| Error::TooLarge)?; + result.push_str("/").map_err(|_| Error::TooLarge)?; + result.push_str(st.as_ref()).map_err(|_| Error::TooLarge)?; + } + Topic::DeviceType(st) => { + result + .push_str(device_type()) + .map_err(|_| Error::TooLarge)?; + result.push_str("/").map_err(|_| Error::TooLarge)?; + result.push_str(st.as_ref()).map_err(|_| Error::TooLarge)?; + } + Topic::General(st) => { + result.push_str(st.as_ref()).map_err(|_| Error::TooLarge)?; + } + } + + Ok(()) + } + + /// Converts to a topic containing an [`str`]. Particularly useful for converting from an owned + /// string for match patterns. + pub fn as_ref(&self) -> Topic<&str> { + match self { + Topic::DeviceType(st) => Topic::DeviceType(st.as_ref()), + Topic::Device(st) => Topic::Device(st.as_ref()), + Topic::General(st) => Topic::General(st.as_ref()), + } + } + + /// Subscribes to this topic. If `wait_for_ack` is true then this will wait until confirmation + /// is received from the broker before returning. + pub async fn subscribe(&self, wait_for_ack: bool) -> Result<(), Error> { + let mut subscriber = subscribe().await; + + let mut topic_path = TopicString::new(); + if self.to_string(&mut topic_path).is_err() { + return Err(Error::TooLarge); + } + + let pid = assign_pid().await; + + let mut subscribe_topic_path = String::<256>::new(); + subscribe_topic_path + .push_str(topic_path.as_str()) + .map_err(|_| Error::TooLarge)?; + let subscribe_topic = SubscribeTopic { + topic_path: subscribe_topic_path, + qos: QoS::AtLeastOnce, + }; + + // The size of this vec must match that used by mqttrs. + let topics = match Vec::::from_slice(&[subscribe_topic]) { + Ok(t) => t, + Err(_) => return Err(Error::TooLarge), + }; + + let packet = Packet::Subscribe(Subscribe { pid, topics }); + + send_packet(packet).await?; + + if wait_for_ack { + match select( + async { + loop { + match subscriber.next_message().await { + WaitResult::Lagged(_) => { + // Maybe we missed the message? + } + WaitResult::Message(ControlMessage::Subscribed( + subscribed_pid, + return_code, + )) => { + if subscribed_pid == pid { + if matches!(return_code, SubscribeReturnCodes::Success(_)) { + return Ok(()); + } else { + return Err(Error::IOError); + } + } + } + _ => {} + } + } + }, + Timer::after_millis(CONFIRMATION_TIMEOUT), + ) + .await + { + Either::First(r) => r, + Either::Second(_) => Err(Error::TimedOut), + } + } else { + Ok(()) + } + } + + /// Unsubscribes from a topic. If `wait_for_ack` is true then this will wait until confirmation is + /// received from the broker before returning. + pub async fn unsubscribe(&self, wait_for_ack: bool) -> Result<(), Error> { + let mut subscriber = subscribe().await; + + let mut topic_path = TopicString::new(); + if self.to_string(&mut topic_path).is_err() { + return Err(Error::TooLarge); + } + + let pid = assign_pid().await; + + // The size of this vec must match that used by mqttrs. + let mut unsubscribe_topic_path = String::<256>::new(); + unsubscribe_topic_path + .push_str(topic_path.as_str()) + .map_err(|_| Error::TooLarge)?; + let topics = match Vec::, 5>::from_slice(&[unsubscribe_topic_path]) { + Ok(t) => t, + Err(_) => return Err(Error::TooLarge), + }; + + let packet = Packet::Unsubscribe(Unsubscribe { pid, topics }); + + send_packet(packet).await?; + + if wait_for_ack { + match select( + async { + loop { + match subscriber.next_message().await { + WaitResult::Lagged(_) => { + // Maybe we missed the message? + } + WaitResult::Message(ControlMessage::Unsubscribed(subscribed_pid)) => { + if subscribed_pid == pid { + return Ok(()); + } + } + _ => {} + } + } + }, + Timer::after_millis(CONFIRMATION_TIMEOUT), + ) + .await + { + Either::First(r) => r, + Either::Second(_) => Err(Error::TimedOut), + } + } else { + Ok(()) + } + } +}