Desmond-Dong commited on
Commit
f6b7440
·
1 Parent(s): 12aed5f

Update ESPHome protocol implementation to match official linux-voice-assistant

Browse files

- Replace api_server.py with official ESPHome protocol implementation
- Replace satellite.py with official voice assistant protocol
- Add MediaPlayerEntity for media player support
- Add save_preferences method to ServerState
- Update AudioPlayer to support URLs and file paths
- Remove zeroconf (Reachy Mini handles mDNS discovery)
- Add media_player_entity to ServerState

reachy_mini_ha_voice/api_server.py CHANGED
@@ -1,74 +1,180 @@
1
- """API server for Home Assistant integration."""
2
 
3
  import asyncio
4
- import json
5
  import logging
6
- from typing import Dict, List, Optional
 
 
7
 
8
- from .models import ServerState
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
  _LOGGER = logging.getLogger(__name__)
11
 
12
 
13
- class APIServer:
14
- """API server for Home Assistant."""
15
-
16
- def __init__(self, state: ServerState):
17
- """Initialize API server."""
18
- self.state = state
19
- self._handlers: Dict[str, callable] = {
20
- "hello": self._handle_hello,
21
- "list_entities": self._handle_list_entities,
22
- "get_state": self._handle_get_state,
23
- "subscribe_states": self._handle_subscribe_states,
24
- }
25
-
26
- async def handle_request(self, command: str, payload: dict) -> dict:
27
- """Handle an API request."""
28
- handler = self._handlers.get(command)
29
- if handler:
30
- try:
31
- return await handler(payload)
32
- except Exception as e:
33
- _LOGGER.error("Error handling request %s: %s", command, e)
34
- return {"error": str(e)}
35
- else:
36
- return {"error": f"Unknown command: {command}"}
37
-
38
- async def _handle_hello(self, payload: dict) -> dict:
39
- """Handle hello request."""
40
- return {
41
- "name": self.state.name,
42
- "mac_address": self.state.mac_address,
43
- "version": "1.0.0",
44
- }
45
-
46
- async def _handle_list_entities(self, payload: dict) -> dict:
47
- """Handle list_entities request."""
48
- entities = []
49
- for entity in self.state.entities:
50
- entities.append(
51
- {
52
- "key": entity.key,
53
- "name": entity.name,
54
- "state": entity.state,
55
- "attributes": entity.attributes,
56
- }
57
  )
58
- return {"entities": entities}
59
-
60
- async def _handle_get_state(self, payload: dict) -> dict:
61
- """Handle get_state request."""
62
- key = payload.get("key")
63
- entity = next((e for e in self.state.entities if e.key == key), None)
64
- if entity:
65
- return {
66
- "key": entity.key,
67
- "state": entity.state,
68
- "attributes": entity.attributes,
69
- }
70
- return {"error": "Entity not found"}
71
-
72
- async def _handle_subscribe_states(self, payload: dict) -> dict:
73
- """Handle subscribe_states request."""
74
- return {"result": "ok"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Partial ESPHome server implementation."""
2
 
3
  import asyncio
 
4
  import logging
5
+ from abc import abstractmethod
6
+ from collections.abc import Iterable
7
+ from typing import TYPE_CHECKING, List, Optional
8
 
9
+ # pylint: disable=no-name-in-module
10
+ from aioesphomeapi._frame_helper.packets import make_plain_text_packets
11
+ from aioesphomeapi.api_pb2 import ( # type: ignore[attr-defined]
12
+ AuthenticationRequest,
13
+ AuthenticationResponse,
14
+ DisconnectRequest,
15
+ DisconnectResponse,
16
+ HelloRequest,
17
+ HelloResponse,
18
+ PingRequest,
19
+ PingResponse,
20
+ )
21
+ from aioesphomeapi.core import MESSAGE_TYPE_TO_PROTO
22
+ from google.protobuf import message
23
+
24
+ PROTO_TO_MESSAGE_TYPE = {v: k for k, v in MESSAGE_TYPE_TO_PROTO.items()}
25
 
26
  _LOGGER = logging.getLogger(__name__)
27
 
28
 
29
+ class APIServer(asyncio.Protocol):
30
+
31
+ def __init__(self, name: str) -> None:
32
+ self.name = name
33
+
34
+ self._buffer: Optional[bytes] = None
35
+ self._buffer_len: int = 0
36
+ self._pos: int = 0
37
+ self._transport = None
38
+ self._writelines = None
39
+
40
+ @abstractmethod
41
+ def handle_message(self, msg: message.Message) -> Iterable[message.Message]:
42
+ pass
43
+
44
+ def process_packet(self, msg_type: int, packet_data: bytes) -> None:
45
+ msg_class = MESSAGE_TYPE_TO_PROTO[msg_type]
46
+ msg_inst = msg_class.FromString(packet_data)
47
+
48
+ if isinstance(msg_inst, HelloRequest):
49
+ self.send_messages(
50
+ [
51
+ HelloResponse(
52
+ api_version_major=1,
53
+ api_version_minor=10,
54
+ name=self.name,
55
+ )
56
+ ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  )
58
+ return
59
+
60
+ if isinstance(msg_inst, AuthenticationRequest):
61
+ self.send_messages([AuthenticationResponse()])
62
+ elif isinstance(msg_inst, DisconnectRequest):
63
+ self.send_messages([DisconnectResponse()])
64
+ _LOGGER.debug("Disconnect requested")
65
+ if self._transport:
66
+ self._transport.close()
67
+ self._transport = None
68
+ self._writelines = None
69
+ elif isinstance(msg_inst, PingRequest):
70
+ self.send_messages([PingResponse()])
71
+ elif msgs := self.handle_message(msg_inst):
72
+ if isinstance(msgs, message.Message):
73
+ msgs = [msgs]
74
+
75
+ self.send_messages(msgs)
76
+
77
+ def send_messages(self, msgs: List[message.Message]):
78
+ if self._writelines is None:
79
+ return
80
+
81
+ packets = [
82
+ (PROTO_TO_MESSAGE_TYPE[msg.__class__], msg.SerializeToString())
83
+ for msg in msgs
84
+ ]
85
+ packet_bytes = make_plain_text_packets(packets)
86
+ self._writelines(packet_bytes)
87
+
88
+ def connection_made(self, transport) -> None:
89
+ self._transport = transport
90
+ self._writelines = transport.writelines
91
+
92
+ def data_received(self, data: bytes):
93
+ if self._buffer is None:
94
+ self._buffer = data
95
+ self._buffer_len = len(data)
96
+ else:
97
+ self._buffer += data
98
+ self._buffer_len += len(data)
99
+
100
+ while self._buffer_len >= 3:
101
+ self._pos = 0
102
+ # Read preamble, which should always 0x00
103
+ if (preamble := self._read_varuint()) != 0x00:
104
+ _LOGGER.error("Incorrect preamble: %s", preamble)
105
+ return
106
+
107
+ if (length := self._read_varuint()) == -1:
108
+ _LOGGER.error("Incorrect length: %s", length)
109
+ return
110
+
111
+ if (msg_type := self._read_varuint()) == -1:
112
+ _LOGGER.error("Incorrect message type: %s", msg_type)
113
+ return
114
+
115
+ if length == 0:
116
+ # Empty message (allowed)
117
+ self._remove_from_buffer()
118
+ self.process_packet(msg_type, b"")
119
+ continue
120
+
121
+ if (packet_data := self._read(length)) is None:
122
+ return
123
+
124
+ self._remove_from_buffer()
125
+ self.process_packet(msg_type, packet_data)
126
+
127
+ def _read(self, length: int) -> bytes | None:
128
+ """Read exactly length bytes from the buffer or None if all the bytes are not yet available."""
129
+ new_pos = self._pos + length
130
+ if self._buffer_len < new_pos:
131
+ return None
132
+ original_pos = self._pos
133
+ self._pos = new_pos
134
+ if TYPE_CHECKING:
135
+ assert self._buffer is not None, "Buffer should be set"
136
+ cstr = self._buffer
137
+ # Important: we must keep the bounds check (self._buffer_len < new_pos)
138
+ # above to verify we never try to read past the end of the buffer
139
+ return cstr[original_pos:new_pos]
140
+
141
+ def connection_lost(self, exc):
142
+ self._transport = None
143
+ self._writelines = None
144
+
145
+ def _read_varuint(self) -> int:
146
+ """Read a varuint from the buffer or -1 if the buffer runs out of bytes."""
147
+ if not self._buffer:
148
+ return -1
149
+
150
+ result = 0
151
+ bitpos = 0
152
+ cstr = self._buffer
153
+ while self._buffer_len > self._pos:
154
+ val = cstr[self._pos]
155
+ self._pos += 1
156
+ result |= (val & 0x7F) << bitpos
157
+ if (val & 0x80) == 0:
158
+ return result
159
+ bitpos += 7
160
+ return -1
161
+
162
+ def _remove_from_buffer(self) -> None:
163
+ """Remove data from the buffer."""
164
+ end_of_frame_pos = self._pos
165
+ self._buffer_len -= end_of_frame_pos
166
+ if self._buffer_len == 0:
167
+ # This is the best case scenario, we can just set the buffer to None
168
+ # and don't have to copy the data. This is the most common case as well.
169
+ self._buffer = None
170
+ return
171
+ if TYPE_CHECKING:
172
+ assert self._buffer is not None, "Buffer should be set"
173
+ # This is the worst case scenario, we have to copy the data
174
+ # and can't just use the buffer directly. This should only happen
175
+ # when we read multiple frames at once because the event loop
176
+ # is blocked and we cannot pull the data out of the buffer fast enough.
177
+ cstr = self._buffer
178
+ # Important: we must use the explicit length for the slice
179
+ # since Cython will stop at any '\0' character if we don't
180
+ self._buffer = cstr[end_of_frame_pos : self._buffer_len + end_of_frame_pos]
reachy_mini_ha_voice/app.py CHANGED
@@ -24,7 +24,6 @@ from .models import (
24
  )
25
  from .satellite import VoiceSatelliteProtocol
26
  from .util import get_mac
27
- from .zeroconf import HomeAssistantZeroconf
28
 
29
  _LOGGER = logging.getLogger(__name__)
30
  _MODULE_DIR = Path(__file__).parent
@@ -124,6 +123,7 @@ class ReachyMiniHAVoiceApp(ReachyMiniApp):
124
  refractory_seconds=2.0,
125
  download_dir=_REPO_DIR / "local",
126
  reachy_integration=None, # Not using Reachy integration for now
 
127
  )
128
 
129
  def _load_wake_words(self) -> Dict[str, AvailableWakeWord]:
@@ -186,16 +186,12 @@ class ReachyMiniHAVoiceApp(ReachyMiniApp):
186
  lambda: VoiceSatelliteProtocol(state), host="0.0.0.0", port=6053
187
  )
188
 
189
- # Auto discovery (zeroconf, mDNS)
190
- discovery = HomeAssistantZeroconf(port=6053, name="ReachyMini")
191
- await discovery.register_server()
192
-
193
  try:
194
  async with server:
195
  _LOGGER.info("ESPHome server started on port 6053")
196
  await server.serve_forever()
197
  finally:
198
- await discovery.unregister_server()
199
 
200
  def _process_audio(self, state: ServerState) -> None:
201
  """Process audio from microphone."""
 
24
  )
25
  from .satellite import VoiceSatelliteProtocol
26
  from .util import get_mac
 
27
 
28
  _LOGGER = logging.getLogger(__name__)
29
  _MODULE_DIR = Path(__file__).parent
 
123
  refractory_seconds=2.0,
124
  download_dir=_REPO_DIR / "local",
125
  reachy_integration=None, # Not using Reachy integration for now
126
+ media_player_entity=None,
127
  )
128
 
129
  def _load_wake_words(self) -> Dict[str, AvailableWakeWord]:
 
186
  lambda: VoiceSatelliteProtocol(state), host="0.0.0.0", port=6053
187
  )
188
 
 
 
 
 
189
  try:
190
  async with server:
191
  _LOGGER.info("ESPHome server started on port 6053")
192
  await server.serve_forever()
193
  finally:
194
+ _LOGGER.info("ESPHome server stopped")
195
 
196
  def _process_audio(self, state: ServerState) -> None:
197
  """Process audio from microphone."""
reachy_mini_ha_voice/entity.py CHANGED
@@ -1,36 +1,94 @@
1
  """Entity management for Home Assistant."""
2
 
3
  import logging
4
- from typing import Dict, List
 
 
 
 
 
 
 
 
5
 
6
  from .models import Entity
7
 
8
  _LOGGER = logging.getLogger(__name__)
9
 
10
 
11
- class EntityManager:
12
- """Manage Home Assistant entities."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
- def __init__(self):
15
- """Initialize entity manager."""
16
- self._entities: Dict[str, Entity] = {}
17
 
18
- def add_entity(self, entity: Entity) -> None:
19
- """Add an entity."""
20
- self._entities[entity.key] = entity
21
- _LOGGER.debug("Added entity: %s", entity.key)
 
22
 
23
- def update_entity(self, key: str, state: str, attributes: Dict[str, str]) -> None:
24
- """Update an entity."""
25
- if key in self._entities:
26
- self._entities[key].state = state
27
- self._entities[key].attributes.update(attributes)
28
- _LOGGER.debug("Updated entity: %s", key)
29
 
30
- def get_entity(self, key: str) -> Entity:
31
- """Get an entity by key."""
32
- return self._entities.get(key)
 
 
33
 
34
- def list_entities(self) -> List[Entity]:
35
- """List all entities."""
36
- return list(self._entities.values())
 
 
 
1
  """Entity management for Home Assistant."""
2
 
3
  import logging
4
+ from typing import Dict, List, Optional
5
+
6
+ # pylint: disable=no-name-in-module
7
+ from aioesphomeapi.api_pb2 import ( # type: ignore[attr-defined]
8
+ ListEntitiesMediaPlayersResponse,
9
+ MediaPlayerCommandRequest,
10
+ TextSensorStateResponse,
11
+ )
12
+ from aioesphomeapi.model import MediaPlayerState
13
 
14
  from .models import Entity
15
 
16
  _LOGGER = logging.getLogger(__name__)
17
 
18
 
19
+ class MediaPlayerEntity(Entity):
20
+ """Media player entity for voice assistant."""
21
+
22
+ def __init__(
23
+ self, server, key: int, name: str, object_id: str, music_player, announce_player
24
+ ):
25
+ """Initialize media player entity."""
26
+ super().__init__(key=key, name=name, state="idle", attributes={})
27
+ self.server = server
28
+ self.object_id = object_id
29
+ self.music_player = music_player
30
+ self.announce_player = announce_player
31
+ self._volume = 1.0
32
+ self._position = 0
33
+ self._duration = 0
34
+
35
+ def handle_message(self, msg):
36
+ """Handle a message."""
37
+ if isinstance(msg, ListEntitiesMediaPlayersResponse):
38
+ yield self.get_list_entities_response()
39
+ elif isinstance(msg, MediaPlayerCommandRequest):
40
+ self.handle_command(msg)
41
+
42
+ def get_list_entities_response(self):
43
+ """Get list entities response."""
44
+ from aioesphomeapi.api_pb2 import ListEntitiesMediaPlayersResponse
45
+
46
+ return ListEntitiesMediaPlayersResponse(
47
+ object_id=self.object_id,
48
+ key=self.key,
49
+ name=self.name,
50
+ )
51
+
52
+ def handle_command(self, msg):
53
+ """Handle a media player command."""
54
+ if msg.command == MediaPlayerCommandRequest.PLAY:
55
+ if msg.url:
56
+ self.play([msg.url])
57
+ elif msg.command == MediaPlayerCommandRequest.PAUSE:
58
+ self.music_player.stop()
59
+ elif msg.command == MediaPlayerCommandRequest.STOP:
60
+ self.music_player.stop()
61
+ elif msg.command == MediaPlayerCommandRequest.VOLUME_SET:
62
+ self._volume = msg.volume / 255.0
63
+ elif msg.command == MediaPlayerCommandRequest.MUTE:
64
+ self._volume = 0.0 if msg.mute else 1.0
65
+
66
+ def play(self, urls, announcement=False, done_callback=None):
67
+ """Play media."""
68
+ _LOGGER.debug("Playing: %s", urls)
69
+ player = self.announce_player if announcement else self.music_player
70
 
71
+ for url in urls:
72
+ try:
73
+ from urllib.request import urlopen
74
 
75
+ with urlopen(url) as response:
76
+ audio_data = response.read()
77
+ player.play(audio_data)
78
+ except Exception as e:
79
+ _LOGGER.error("Error playing %s: %s", url, e)
80
 
81
+ if done_callback:
82
+ done_callback()
 
 
 
 
83
 
84
+ def duck(self):
85
+ """Duck the volume."""
86
+ _LOGGER.debug("Ducking media player")
87
+ # Reduce volume by 50%
88
+ # self._volume *= 0.5
89
 
90
+ def unduck(self):
91
+ """Unduck the volume."""
92
+ _LOGGER.debug("Unducking media player")
93
+ # Restore volume
94
+ # self._volume = min(1.0, self._volume * 2.0)
reachy_mini_ha_voice/models.py CHANGED
@@ -73,6 +73,20 @@ class ServerState:
73
  satellite: Optional["VoiceSatelliteProtocol"] = None
74
  wake_words_changed: bool = True
75
  reachy_integration: Optional["ReachyMiniIntegration"] = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
 
77
 
78
  @dataclass
@@ -93,14 +107,33 @@ class AudioPlayer:
93
  self.device = device
94
  self._stream = None
95
  self._pyaudio = None
 
96
 
97
- def play(self, audio_data: bytes) -> None:
98
- """Play audio data."""
99
  import pyaudio
100
 
101
  if self._pyaudio is None:
102
  self._pyaudio = pyaudio.PyAudio()
103
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
104
  # Assume 16-bit PCM, 16kHz, mono
105
  if self._stream is None:
106
  self._stream = self._pyaudio.open(
@@ -113,11 +146,31 @@ class AudioPlayer:
113
 
114
  self._stream.write(audio_data)
115
 
116
- def close(self) -> None:
117
- """Close the audio player."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
  if self._stream is not None:
 
119
  self._stream.close()
120
  self._stream = None
 
 
 
 
121
  if self._pyaudio is not None:
122
  self._pyaudio.terminate()
123
  self._pyaudio = None
 
73
  satellite: Optional["VoiceSatelliteProtocol"] = None
74
  wake_words_changed: bool = True
75
  reachy_integration: Optional["ReachyMiniIntegration"] = None
76
+ media_player_entity: Optional["MediaPlayerEntity"] = None
77
+
78
+ def save_preferences(self) -> None:
79
+ """Save preferences to file."""
80
+ try:
81
+ import json
82
+
83
+ with open(self.preferences_path, "w", encoding="utf-8") as f:
84
+ json.dump(
85
+ {"active_wake_words": self.preferences.active_wake_words},
86
+ f,
87
+ )
88
+ except Exception as e:
89
+ _LOGGER.error("Error saving preferences: %s", e)
90
 
91
 
92
  @dataclass
 
107
  self.device = device
108
  self._stream = None
109
  self._pyaudio = None
110
+ self._ducked = False
111
 
112
+ def play(self, audio_data: Union[bytes, str], done_callback=None) -> None:
113
+ """Play audio data or URL."""
114
  import pyaudio
115
 
116
  if self._pyaudio is None:
117
  self._pyaudio = pyaudio.PyAudio()
118
 
119
+ if isinstance(audio_data, str):
120
+ # It's a URL or file path
121
+ try:
122
+ from urllib.request import urlopen
123
+
124
+ if audio_data.startswith("http://") or audio_data.startswith("https://"):
125
+ with urlopen(audio_data) as response:
126
+ audio_data = response.read()
127
+ else:
128
+ # It's a file path
129
+ with open(audio_data, "rb") as f:
130
+ audio_data = f.read()
131
+ except Exception as e:
132
+ _LOGGER.error("Error loading audio: %s", e)
133
+ if done_callback:
134
+ done_callback()
135
+ return
136
+
137
  # Assume 16-bit PCM, 16kHz, mono
138
  if self._stream is None:
139
  self._stream = self._pyaudio.open(
 
146
 
147
  self._stream.write(audio_data)
148
 
149
+ if done_callback:
150
+ done_callback()
151
+
152
+ def duck(self) -> None:
153
+ """Duck the volume (reduce by 50%)."""
154
+ self._ducked = True
155
+ # For simple implementation, we just note the state
156
+ # In a full implementation, we would actually reduce the volume
157
+
158
+ def unduck(self) -> None:
159
+ """Unduck the volume (restore to normal)."""
160
+ self._ducked = False
161
+ # For simple implementation, we just note the state
162
+ # In a full implementation, we would actually restore the volume
163
+
164
+ def stop(self) -> None:
165
+ """Stop playing and reset the stream."""
166
  if self._stream is not None:
167
+ self._stream.stop_stream()
168
  self._stream.close()
169
  self._stream = None
170
+
171
+ def close(self) -> None:
172
+ """Close the audio player."""
173
+ self.stop()
174
  if self._pyaudio is not None:
175
  self._pyaudio.terminate()
176
  self._pyaudio = None
reachy_mini_ha_voice/satellite.py CHANGED
@@ -1,136 +1,393 @@
1
- """Voice satellite protocol implementation for ESPHome."""
2
 
3
- import asyncio
4
- import json
5
  import logging
6
- import struct
7
- from typing import Optional, Union
 
 
 
 
 
8
 
9
- from .models import ServerState
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
  _LOGGER = logging.getLogger(__name__)
12
 
13
 
14
- class VoiceSatelliteProtocol(asyncio.Protocol):
15
- """ESPHome voice satellite protocol implementation."""
16
 
17
- def __init__(self, state: ServerState):
18
- """Initialize protocol."""
19
- self.state = state
20
- self.transport: Optional[asyncio.Transport] = None
21
- self._buffer = bytearray()
22
- self._connected = False
23
 
24
- def connection_made(self, transport: asyncio.Transport) -> None:
25
- """Handle new connection."""
26
- self.transport = transport
27
  self.state.satellite = self
28
- self._connected = True
29
- _LOGGER.info("Client connected: %s", transport.get_extra_info("peername"))
30
-
31
- def connection_lost(self, exc: Optional[Exception]) -> None:
32
- """Handle connection loss."""
33
- self._connected = False
34
- self.state.satellite = None
35
- _LOGGER.info("Client disconnected")
36
- if exc:
37
- _LOGGER.error("Connection error: %s", exc)
38
-
39
- def data_received(self, data: bytes) -> None:
40
- """Handle incoming data."""
41
- self._buffer.extend(data)
42
-
43
- while len(self._buffer) >= 3:
44
- # Parse message header
45
- msg_type = self._buffer[0]
46
- msg_length = struct.unpack(">H", self._buffer[1:3])[0]
47
-
48
- if len(self._buffer) < 3 + msg_length:
49
- # Need more data
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  break
51
 
52
- # Extract message
53
- msg_data = bytes(self._buffer[3 : 3 + msg_length])
54
- self._buffer = self._buffer[3 + msg_length :]
55
-
56
- # Process message
57
- asyncio.create_task(self._process_message(msg_type, msg_data))
58
-
59
- async def _process_message(self, msg_type: int, msg_data: bytes) -> None:
60
- """Process a message."""
61
- try:
62
- if msg_type == 0x01: # Hello
63
- await self._handle_hello(msg_data)
64
- elif msg_type == 0x02: # Voice Assistant Start
65
- await self._handle_voice_assistant_start(msg_data)
66
- elif msg_type == 0x03: # Voice Assistant End
67
- await self._handle_voice_assistant_end(msg_data)
68
- elif msg_type == 0x04: # TTS Audio
69
- await self._handle_tts_audio(msg_data)
70
- else:
71
- _LOGGER.warning("Unknown message type: %s", msg_type)
72
- except Exception as e:
73
- _LOGGER.error("Error processing message: %s", e)
74
-
75
- async def _handle_hello(self, data: bytes) -> None:
76
- """Handle hello message."""
77
- _LOGGER.debug("Received hello message")
78
- # Send hello response
79
- response = self._build_message(0x01, json.dumps({"name": self.state.name}))
80
- self._send_message(response)
81
-
82
- async def _handle_voice_assistant_start(self, data: bytes) -> None:
83
- """Handle voice assistant start message."""
84
- _LOGGER.info("Voice assistant started")
85
- # Play wake sound
86
- try:
87
- with open(self.state.wakeup_sound, "rb") as f:
88
- self.state.tts_player.play(f.read())
89
- except Exception as e:
90
- _LOGGER.error("Error playing wake sound: %s", e)
91
-
92
- async def _handle_voice_assistant_end(self, data: bytes) -> None:
93
- """Handle voice assistant end message."""
94
- _LOGGER.info("Voice assistant ended")
95
-
96
- async def _handle_tts_audio(self, data: bytes) -> None:
97
- """Handle TTS audio message."""
98
- try:
99
- self.state.tts_player.play(data)
100
- except Exception as e:
101
- _LOGGER.error("Error playing TTS audio: %s", e)
102
 
103
  def handle_audio(self, audio_chunk: bytes) -> None:
104
- """Handle audio chunk from microphone."""
105
- if self._connected and self.transport:
106
- # Send audio data to Home Assistant
107
- message = self._build_message(0x10, audio_chunk)
108
- self._send_message(message)
109
-
110
- def wakeup(self, wake_word) -> None:
111
- """Handle wake word detection."""
112
- _LOGGER.info("Wake word detected: %s", wake_word.id)
113
- # Send wake notification to Home Assistant
114
- message = self._build_message(
115
- 0x11, json.dumps({"wake_word": wake_word.wake_word})
 
 
 
 
 
 
116
  )
117
- self._send_message(message)
 
 
118
 
119
  def stop(self) -> None:
120
- """Handle stop word detection."""
121
- _LOGGER.info("Stop word detected")
122
- # Send stop notification to Home Assistant
123
- message = self._build_message(0x12, json.dumps({"action": "stop"}))
124
- self._send_message(message)
125
-
126
- def _build_message(self, msg_type: int, data: Union[str, bytes]) -> bytes:
127
- """Build a message."""
128
- if isinstance(data, str):
129
- data = data.encode("utf-8")
130
- length = len(data)
131
- return bytes([msg_type]) + struct.pack(">H", length) + data
132
-
133
- def _send_message(self, message: bytes) -> None:
134
- """Send a message."""
135
- if self._connected and self.transport:
136
- self.transport.write(message)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Voice satellite protocol."""
2
 
3
+ import hashlib
 
4
  import logging
5
+ import posixpath
6
+ import shutil
7
+ import time
8
+ from collections.abc import Iterable
9
+ from typing import Dict, Optional, Set, Union
10
+ from urllib.parse import urlparse, urlunparse
11
+ from urllib.request import urlopen
12
 
13
+ # pylint: disable=no-name-in-module
14
+ from aioesphomeapi.api_pb2 import ( # type: ignore[attr-defined]
15
+ DeviceInfoRequest,
16
+ DeviceInfoResponse,
17
+ ListEntitiesDoneResponse,
18
+ ListEntitiesRequest,
19
+ ListEntitiesServicesResponse,
20
+ ListEntitiesServicesArgumentsResponse,
21
+ MediaPlayerCommandRequest,
22
+ SubscribeHomeAssistantStatesRequest,
23
+ VoiceAssistantAnnounceFinished,
24
+ VoiceAssistantAnnounceRequest,
25
+ VoiceAssistantAudio,
26
+ VoiceAssistantConfigurationRequest,
27
+ VoiceAssistantConfigurationResponse,
28
+ VoiceAssistantEventResponse,
29
+ VoiceAssistantExternalWakeWord,
30
+ VoiceAssistantRequest,
31
+ VoiceAssistantSetConfiguration,
32
+ VoiceAssistantTimerEventResponse,
33
+ VoiceAssistantWakeWord,
34
+ )
35
+ from aioesphomeapi.model import (
36
+ VoiceAssistantEventType,
37
+ VoiceAssistantFeature,
38
+ VoiceAssistantTimerEventType,
39
+ )
40
+ from google.protobuf import message
41
+ from pymicro_wakeword import MicroWakeWord
42
+ from pyopen_wakeword import OpenWakeWord
43
+
44
+ from .api_server import APIServer
45
+ from .entity import MediaPlayerEntity
46
+ from .models import AvailableWakeWord, ServerState, WakeWordType
47
 
48
  _LOGGER = logging.getLogger(__name__)
49
 
50
 
51
+ class VoiceSatelliteProtocol(APIServer):
 
52
 
53
+ def __init__(self, state: ServerState) -> None:
54
+ super().__init__(state.name)
 
 
 
 
55
 
56
+ self.state = state
 
 
57
  self.state.satellite = self
58
+
59
+ if self.state.media_player_entity is None:
60
+ self.state.media_player_entity = MediaPlayerEntity(
61
+ server=self,
62
+ key=len(state.entities),
63
+ name="Media Player",
64
+ object_id="reachy_mini_ha_voice_media_player",
65
+ music_player=state.music_player,
66
+ announce_player=state.tts_player,
67
+ )
68
+ self.state.entities.append(self.state.media_player_entity)
69
+
70
+ self._is_streaming_audio = False
71
+ self._tts_url: Optional[str] = None
72
+ self._tts_played = False
73
+ self._continue_conversation = False
74
+ self._timer_finished = False
75
+ self._external_wake_words: Dict[str, VoiceAssistantExternalWakeWord] = {}
76
+
77
+ def handle_voice_event(
78
+ self, event_type: VoiceAssistantEventType, data: Dict[str, str]
79
+ ) -> None:
80
+ _LOGGER.debug("Voice event: type=%s, data=%s", event_type.name, data)
81
+
82
+ if event_type == VoiceAssistantEventType.VOICE_ASSISTANT_RUN_START:
83
+ self._tts_url = data.get("url")
84
+ self._tts_played = False
85
+ self._continue_conversation = False
86
+ elif event_type in (
87
+ VoiceAssistantEventType.VOICE_ASSISTANT_STT_VAD_END,
88
+ VoiceAssistantEventType.VOICE_ASSISTANT_STT_END,
89
+ ):
90
+ self._is_streaming_audio = False
91
+ elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_INTENT_PROGRESS:
92
+ if data.get("tts_start_streaming") == "1":
93
+ # Start streaming early
94
+ self.play_tts()
95
+ elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_INTENT_END:
96
+ if data.get("continue_conversation") == "1":
97
+ self._continue_conversation = True
98
+ elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_TTS_END:
99
+ self._tts_url = data.get("url")
100
+ self.play_tts()
101
+ elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_RUN_END:
102
+ self._is_streaming_audio = False
103
+ if not self._tts_played:
104
+ self._tts_finished()
105
+
106
+ self._tts_played = False
107
+
108
+ # TODO: handle error
109
+
110
+ def handle_timer_event(
111
+ self,
112
+ event_type: VoiceAssistantTimerEventType,
113
+ msg: VoiceAssistantTimerEventResponse,
114
+ ) -> None:
115
+ _LOGGER.debug("Timer event: type=%s", event_type.name)
116
+ if event_type == VoiceAssistantTimerEventType.VOICE_ASSISTANT_TIMER_FINISHED:
117
+ if not self._timer_finished:
118
+ self.state.active_wake_words.add(self.state.stop_word.id)
119
+ self._timer_finished = True
120
+ self.duck()
121
+ self._play_timer_finished()
122
+
123
+ def handle_message(self, msg: message.Message) -> Iterable[message.Message]:
124
+ if isinstance(msg, VoiceAssistantEventResponse):
125
+ # Pipeline event
126
+ data: Dict[str, str] = {}
127
+ for arg in msg.data:
128
+ data[arg.name] = arg.value
129
+
130
+ self.handle_voice_event(VoiceAssistantEventType(msg.event_type), data)
131
+ elif isinstance(msg, VoiceAssistantAnnounceRequest):
132
+ _LOGGER.debug("Announcing: %s", msg.text)
133
+
134
+ assert self.state.media_player_entity is not None
135
+
136
+ urls = []
137
+ if msg.preannounce_media_id:
138
+ urls.append(msg.preannounce_media_id)
139
+
140
+ urls.append(msg.media_id)
141
+
142
+ self.state.active_wake_words.add(self.state.stop_word.id)
143
+ self._continue_conversation = msg.start_conversation
144
+
145
+ self.duck()
146
+ yield from self.state.media_player_entity.play(
147
+ urls, announcement=True, done_callback=self._tts_finished
148
+ )
149
+ elif isinstance(msg, VoiceAssistantTimerEventResponse):
150
+ self.handle_timer_event(VoiceAssistantTimerEventType(msg.event_type), msg)
151
+ elif isinstance(msg, DeviceInfoRequest):
152
+ yield DeviceInfoResponse(
153
+ uses_password=False,
154
+ name=self.state.name,
155
+ mac_address=self.state.mac_address,
156
+ voice_assistant_feature_flags=(
157
+ VoiceAssistantFeature.VOICE_ASSISTANT
158
+ | VoiceAssistantFeature.API_AUDIO
159
+ | VoiceAssistantFeature.ANNOUNCE
160
+ | VoiceAssistantFeature.START_CONVERSATION
161
+ | VoiceAssistantFeature.TIMERS
162
+ ),
163
+ )
164
+ elif isinstance(
165
+ msg,
166
+ (
167
+ ListEntitiesRequest,
168
+ SubscribeHomeAssistantStatesRequest,
169
+ MediaPlayerCommandRequest,
170
+ ),
171
+ ):
172
+ for entity in self.state.entities:
173
+ yield from entity.handle_message(msg)
174
+
175
+ if isinstance(msg, ListEntitiesRequest):
176
+ yield ListEntitiesDoneResponse()
177
+ elif isinstance(msg, VoiceAssistantConfigurationRequest):
178
+ available_wake_words = [
179
+ VoiceAssistantWakeWord(
180
+ id=ww.id,
181
+ wake_word=ww.wake_word,
182
+ trained_languages=ww.trained_languages,
183
+ )
184
+ for ww in self.state.available_wake_words.values()
185
+ ]
186
+
187
+ for eww in msg.external_wake_words:
188
+ if eww.model_type != "micro":
189
+ continue
190
+
191
+ available_wake_words.append(
192
+ VoiceAssistantWakeWord(
193
+ id=eww.id,
194
+ wake_word=eww.wake_word,
195
+ trained_languages=eww.trained_languages,
196
+ )
197
+ )
198
+
199
+ self._external_wake_words[eww.id] = eww
200
+
201
+ yield VoiceAssistantConfigurationResponse(
202
+ available_wake_words=available_wake_words,
203
+ active_wake_words=[
204
+ ww.id
205
+ for ww in self.state.wake_words.values()
206
+ if ww.id in self.state.active_wake_words
207
+ ],
208
+ max_active_wake_words=2,
209
+ )
210
+ _LOGGER.info("Connected to Home Assistant")
211
+ elif isinstance(msg, VoiceAssistantSetConfiguration):
212
+ # Change active wake words
213
+ active_wake_words: Set[str] = set()
214
+
215
+ for wake_word_id in msg.active_wake_words:
216
+ if wake_word_id in self.state.wake_words:
217
+ # Already active
218
+ active_wake_words.add(wake_word_id)
219
+ continue
220
+
221
+ model_info = self.state.available_wake_words.get(wake_word_id)
222
+ if not model_info:
223
+ # Check external wake words (may require download)
224
+ external_wake_word = self._external_wake_words.get(wake_word_id)
225
+ if not external_wake_word:
226
+ continue
227
+
228
+ model_info = self._download_external_wake_word(external_wake_word)
229
+ if not model_info:
230
+ continue
231
+
232
+ self.state.available_wake_words[wake_word_id] = model_info
233
+
234
+ _LOGGER.debug("Loading wake word: %s", model_info.wake_word_path)
235
+ self.state.wake_words[wake_word_id] = model_info.load()
236
+
237
+ _LOGGER.info("Wake word set: %s", wake_word_id)
238
+ active_wake_words.add(wake_word_id)
239
  break
240
 
241
+ self.state.active_wake_words = active_wake_words
242
+ _LOGGER.debug("Active wake words: %s", active_wake_words)
243
+
244
+ self.state.preferences.active_wake_words = list(active_wake_words)
245
+ self.state.save_preferences()
246
+ self.state.wake_words_changed = True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
247
 
248
  def handle_audio(self, audio_chunk: bytes) -> None:
249
+
250
+ if not self._is_streaming_audio:
251
+ return
252
+
253
+ self.send_messages([VoiceAssistantAudio(data=audio_chunk)])
254
+
255
+ def wakeup(self, wake_word: Union[MicroWakeWord, OpenWakeWord]) -> None:
256
+ if self._timer_finished:
257
+ # Stop timer instead
258
+ self._timer_finished = False
259
+ self.state.tts_player.stop()
260
+ _LOGGER.debug("Stopping timer finished sound")
261
+ return
262
+
263
+ wake_word_phrase = wake_word.wake_word
264
+ _LOGGER.debug("Detected wake word: %s", wake_word_phrase)
265
+ self.send_messages(
266
+ [VoiceAssistantRequest(start=True, wake_word_phrase=wake_word_phrase)]
267
  )
268
+ self.duck()
269
+ self._is_streaming_audio = True
270
+ self.state.tts_player.play(self.state.wakeup_sound)
271
 
272
  def stop(self) -> None:
273
+ self.state.active_wake_words.discard(self.state.stop_word.id)
274
+ self.state.tts_player.stop()
275
+
276
+ if self._timer_finished:
277
+ self._timer_finished = False
278
+ _LOGGER.debug("Stopping timer finished sound")
279
+ else:
280
+ _LOGGER.debug("TTS response stopped manually")
281
+ self._tts_finished()
282
+
283
+ def play_tts(self) -> None:
284
+ if (not self._tts_url) or self._tts_played:
285
+ return
286
+
287
+ self._tts_played = True
288
+ _LOGGER.debug("Playing TTS response: %s", self._tts_url)
289
+
290
+ self.state.active_wake_words.add(self.state.stop_word.id)
291
+ self.state.tts_player.play(self._tts_url, done_callback=self._tts_finished)
292
+
293
+ def duck(self) -> None:
294
+ _LOGGER.debug("Ducking music")
295
+ self.state.music_player.duck()
296
+
297
+ def unduck(self) -> None:
298
+ _LOGGER.debug("Unducking music")
299
+ self.state.music_player.unduck()
300
+
301
+ def _tts_finished(self) -> None:
302
+ self.state.active_wake_words.discard(self.state.stop_word.id)
303
+ self.send_messages([VoiceAssistantAnnounceFinished()])
304
+
305
+ if self._continue_conversation:
306
+ self.send_messages([VoiceAssistantRequest(start=True)])
307
+ self._is_streaming_audio = True
308
+ _LOGGER.debug("Continuing conversation")
309
+ else:
310
+ self.unduck()
311
+
312
+ _LOGGER.debug("TTS response finished")
313
+
314
+ def _play_timer_finished(self) -> None:
315
+ if not self._timer_finished:
316
+ self.unduck()
317
+ return
318
+
319
+ self.state.tts_player.play(
320
+ self.state.timer_finished_sound,
321
+ done_callback=lambda: time.sleep(1.0) or self._play_timer_finished(),
322
+ )
323
+
324
+ def connection_lost(self, exc):
325
+ super().connection_lost(exc)
326
+ _LOGGER.info("Disconnected from Home Assistant")
327
+
328
+ def _download_external_wake_word(
329
+ self, external_wake_word: VoiceAssistantExternalWakeWord
330
+ ) -> Optional[AvailableWakeWord]:
331
+ eww_dir = self.state.download_dir / "external_wake_words"
332
+ eww_dir.mkdir(parents=True, exist_ok=True)
333
+
334
+ config_path = eww_dir / f"{external_wake_word.id}.json"
335
+ should_download_config = not config_path.exists()
336
+
337
+ # Check if we need to download the model file
338
+ model_path = eww_dir / f"{external_wake_word.id}.tflite"
339
+ should_download_model = True
340
+ if model_path.exists():
341
+ model_size = model_path.stat().st_size
342
+ if model_size == external_wake_word.model_size:
343
+ with open(model_path, "rb") as model_file:
344
+ model_hash = hashlib.sha256(model_file.read()).hexdigest()
345
+
346
+ if model_hash == external_wake_word.model_hash:
347
+ should_download_model = False
348
+ _LOGGER.debug(
349
+ "Model size and hash match for %s. Skipping download.",
350
+ external_wake_word.id,
351
+ )
352
+
353
+ if should_download_config or should_download_model:
354
+ # Download config
355
+ _LOGGER.debug("Downloading %s to %s", external_wake_word.url, config_path)
356
+ with urlopen(external_wake_word.url) as request:
357
+ if request.status != 200:
358
+ _LOGGER.warning(
359
+ "Failed to download: %s, status=%s",
360
+ external_wake_word.url,
361
+ request.status,
362
+ )
363
+ return None
364
+
365
+ with open(config_path, "wb") as model_file:
366
+ shutil.copyfileobj(request, model_file)
367
+
368
+ if should_download_model:
369
+ # Download model file
370
+ parsed_url = urlparse(external_wake_word.url)
371
+ parsed_url = parsed_url._replace(
372
+ path=posixpath.join(posixpath.dirname(parsed_url.path), model_path.name)
373
+ )
374
+ model_url = urlunparse(parsed_url)
375
+
376
+ _LOGGER.debug("Downloading %s to %s", model_url, model_path)
377
+ with urlopen(model_url) as request:
378
+ if request.status != 200:
379
+ _LOGGER.warning(
380
+ "Failed to download: %s, status=%s", model_url, request.status
381
+ )
382
+ return None
383
+
384
+ with open(model_path, "wb") as model_file:
385
+ shutil.copyfileobj(request, model_file)
386
+
387
+ return AvailableWakeWord(
388
+ id=external_wake_word.id,
389
+ type=WakeWordType.MICRO_WAKE_WORD,
390
+ wake_word=external_wake_word.wake_word,
391
+ trained_languages=external_wake_word.trained_languages,
392
+ wake_word_path=config_path,
393
+ )