moly_kit/widgets/
stt_input.rs

1use crate::aitk::protocol::{Attachment, BotClient, BotId, EntityId, Message, MessageContent};
2use crate::aitk::utils::asynchronous::{AbortOnDropHandle, spawn_abort_on_drop};
3use crate::utils::makepad::events::EventExt;
4use makepad_widgets::defer_with_redraw::DeferWithRedraw;
5use makepad_widgets::*;
6use std::sync::{Arc, Mutex};
7
/// Client and bot used to turn recorded audio into text.
///
/// The recording is sent to `bot_id` through `client` as a user message with a
/// WAV attachment; the first non-empty text in the response stream is used as
/// the transcription.
#[derive(Clone)]
pub struct SttUtility {
    /// Client used to send the recorded audio.
    pub client: Box<dyn BotClient>,
    /// Bot that performs the speech-to-text conversion.
    pub bot_id: BotId,
}
13
// Script/DSL definitions for the widget's look and layout.
//
// `SttInput` is a single horizontal bar: a transparent cancel icon button on
// the left, a centered status label (elapsed time / "Transcribing..."), and a
// black confirm icon button on the right. The two `HorizontalFiller`s are
// zero-height `Fill`-width views that push the label into the middle.
// `#x0000` is used for the cancel button background; assuming 4-digit hex is
// RGBA (as with `#8888` for the border), this is fully transparent.
script_mod! {
    use mod.prelude.widgets.*
    use mod.widgets.*

    let HorizontalFiller = View { width: Fill, height: 0 }

    let IconButton = Button {
        width: Fit, height: Fit
        padding: Inset { top: 6, bottom: 6, left: 8, right: 8 }
        draw_text +: {
            text_style: theme.font_icons {
                font_size: 12.
            }
        }
        draw_bg +: {
            border_radius: 8.
            border_size: 0.
        }
    }

    mod.widgets.SttInputBase = #(SttInput::register_widget(vm))
    mod.widgets.SttInput = set_type_default() do mod.widgets.SttInputBase {
        flow: Right,
        height: 50,
        align: Align { y: 0.5 },
        spacing: 10,
        padding: 10,
        draw_bg +: {
            color: #fff
            border_radius: 12,
            border_color: #8888,
            border_size: 1.0,
        }

        cancel := IconButton {
            text: "\u{f00d}"
            draw_text +: {
                color: #000,
                color_hover: #000,
                color_down: #000,
                color_focus: #000,
            }
            draw_bg +: {
                color: #x0000
                color_hover: #x0000
                color_down: #x0000
                color_focus: #x0000
            }
        }
        HorizontalFiller {}
        status := Label {
            text: "Recording...",
            draw_text +: {
                color: #000,
                text_style +: { font_size: 11 }
            }
        }
        HorizontalFiller {}
        confirm := IconButton {
            text: "\u{f00c}"
            draw_text +: {
                color: #fff,
                color_hover: #fff,
                color_down: #fff,
                color_focus: #fff,
            }
            draw_bg +: {
                color: #000
                color_hover: #000
                color_down: #000
                color_focus: #000
            }
        }
    }
}
89
/// Microphone samples captured during a recording.
///
/// Shared (behind `Arc<Mutex<_>>`) between the widget and the audio-input
/// callback installed by `SttInput::start_recording`.
#[derive(Clone, Debug, Default)]
struct AudioData {
    /// Samples from input channel 0, appended in capture order.
    pub data: Vec<f32>,
    /// Sample rate reported with the first captured buffer; `None` until audio arrives.
    pub sample_rate: Option<f64>,
}
95
/// Widget actions emitted by [`SttInput`].
#[derive(Clone, Debug, Default)]
pub enum SttInputAction {
    /// Transcription finished with the given (non-empty) text.
    Transcribed(String),
    /// The recording or transcription was cancelled, or produced no usable result.
    Cancelled,
    /// Placeholder default variant (presumably required by the widget-action machinery).
    #[default]
    None,
}
103
/// Phase of the record -> transcribe flow.
#[derive(PartialEq, Clone, Debug, Default)]
enum SttInputState {
    /// Neither recording nor waiting for a transcription.
    #[default]
    Idle,
    /// Capturing microphone audio.
    Recording(RecordingState),
    /// Recording finished; waiting for the transcription result.
    Sending,
}
111
/// Data tracked while a recording is in progress.
#[derive(PartialEq, Clone, Debug)]
struct RecordingState {
    /// `Cx::time_now()` value when recording started; used for the elapsed-time label.
    start_time: f64,
}
116
/// Interval between refreshes of the elapsed-time status label, passed to
/// `cx.start_timeout` (presumably seconds, matching `Cx::time_now()`).
const TIMER_PRECISION: f64 = 0.1;
118
// Recording bar that captures microphone audio, sends it to the configured
// `SttUtility` for transcription, and reports the outcome as `SttInputAction`s.
#[derive(Script, Widget, ScriptHook)]
pub struct SttInput {
    // View that lays out and draws the cancel button, status label and confirm button.
    #[deref]
    pub deref: View,
    // Script object this widget was created from (presumably managed by the derives).
    #[source]
    source: ScriptObjectRef,

    // Current phase of the record -> transcribe flow.
    #[rust]
    state: SttInputState,

    // Client/bot used for transcription; set through `set_stt_utility`.
    #[rust]
    stt_utility: Option<SttUtility>,

    // Sample buffer shared with the audio-input callback; created on first recording.
    #[rust]
    audio_buffer: Option<Arc<Mutex<AudioData>>>,

    // Handle of the in-flight transcription task; dropping it aborts the task.
    #[rust]
    abort_handle: Option<AbortOnDropHandle>,

    // One-shot timeout that drives the elapsed-time label while recording.
    #[rust]
    timer: Timer,
}
141
142impl Widget for SttInput {
143    fn draw_walk(&mut self, cx: &mut Cx2d, scope: &mut Scope, walk: Walk) -> DrawStep {
144        self.deref.draw_walk(cx, scope, walk)
145    }
146
147    fn handle_event(&mut self, cx: &mut Cx, event: &Event, scope: &mut Scope) {
148        self.ui_runner().handle(cx, event, scope, self);
149        self.deref.handle_event(cx, event, scope);
150
151        if self.timer.is_event(event).is_some() {
152            if let SttInputState::Recording(recording_state) = &self.state {
153                let elapsed = Cx::time_now() - recording_state.start_time;
154                self.label(cx, ids!(status))
155                    .set_text(cx, &time_to_minutes_seconds(elapsed));
156                self.timer = cx.start_timeout(TIMER_PRECISION);
157            }
158        }
159
160        if self.button(cx, ids!(confirm)).clicked(event.actions()) {
161            self.finish_recording(cx, scope);
162        }
163
164        if self.button(cx, ids!(cancel)).clicked(event.actions()) {
165            self.cancel_recording(cx, scope);
166        }
167    }
168}
169
170impl SttInput {
171    /// Sets the STT utility to be used for transcription.
172    pub fn set_stt_utility(&mut self, utility: Option<SttUtility>) {
173        self.stt_utility = utility;
174    }
175
176    /// Getter for the current STT utility.
177    pub fn stt_utility(&self) -> Option<&SttUtility> {
178        self.stt_utility.as_ref()
179    }
180
181    /// Begins recording audio from the microphone.
182    pub fn start_recording(&mut self, cx: &mut Cx) {
183        self.button(cx, ids!(confirm)).set_visible(cx, true);
184
185        self.state = SttInputState::Recording(RecordingState {
186            start_time: Cx::time_now(),
187        });
188        self.label(cx, ids!(status))
189            .set_text(cx, &time_to_minutes_seconds(0.));
190        self.timer = cx.start_timeout(TIMER_PRECISION);
191
192        if self.audio_buffer.is_none() {
193            self.audio_buffer = Some(Arc::new(Mutex::new(AudioData::default())));
194        }
195
196        if let Some(arc) = &self.audio_buffer {
197            if let Ok(mut buffer) = arc.lock() {
198                buffer.data.clear();
199                buffer.sample_rate = None;
200            }
201
202            let buffer_clone = arc.clone();
203            cx.audio_input(0, move |info, input_buffer| {
204                let channel = input_buffer.channel(0);
205
206                if let Ok(mut recorded) = buffer_clone.try_lock() {
207                    if recorded.sample_rate.is_none() {
208                        recorded.sample_rate = Some(info.sample_rate);
209                    }
210                    recorded.data.extend_from_slice(channel);
211                }
212            });
213        }
214    }
215
216    fn stop_recording(&mut self, cx: &mut Cx) {
217        cx.audio_input(0, |_, _| {});
218    }
219
220    /// Completes the recording and starts the transcription process.
221    pub fn finish_recording(&mut self, cx: &mut Cx, scope: &mut Scope) {
222        self.stop_recording(cx);
223        self.state = SttInputState::Sending;
224        self.label(cx, ids!(status)).set_text(cx, "Transcribing...");
225        self.button(cx, ids!(confirm)).set_visible(cx, false);
226
227        if let Some(buffer_arc) = self.audio_buffer.clone() {
228            self.process_stt_audio(cx, buffer_arc, scope);
229        }
230    }
231
232    /// Cancels the ongoing recording or transcription.
233    ///
234    /// This stops the audio device and aborts the async transcription
235    /// request.
236    pub fn cancel_recording(&mut self, cx: &mut Cx, _scope: &mut Scope) {
237        self.stop_recording(cx);
238        self.state = SttInputState::Idle;
239        self.abort_handle = None;
240
241        let uid = self.widget_uid();
242        cx.widget_action(uid, SttInputAction::Cancelled);
243    }
244
245    fn process_stt_audio(
246        &mut self,
247        cx: &mut Cx,
248        buffer_arc: Arc<Mutex<AudioData>>,
249        scope: &mut Scope,
250    ) {
251        if let Some(utility) = &self.stt_utility {
252            let mut client = utility.client.clone();
253            let bot_id = utility.bot_id.clone();
254            let ui = self.ui_runner();
255
256            let (samples, sample_rate) = {
257                let guard = buffer_arc.lock().unwrap();
258                (guard.data.clone(), guard.sample_rate)
259            };
260
261            if samples.is_empty() {
262                self.cancel_recording(cx, scope);
263                return;
264            }
265
266            let sample_rate = sample_rate.unwrap_or(48000.0) as u32;
267            let wav_bytes = match crate::utils::audio::build_wav(&samples, sample_rate, 1) {
268                Ok(bytes) => bytes,
269                Err(e) => {
270                    ::log::error!("Error encoding audio: {}", e);
271                    self.cancel_recording(cx, scope);
272                    return;
273                }
274            };
275
276            let attachment = Attachment::from_bytes(
277                "recording.wav".to_string(),
278                Some("audio/wav".to_string()),
279                &wav_bytes,
280            );
281
282            let message = Message {
283                from: EntityId::User,
284                content: MessageContent {
285                    attachments: vec![attachment],
286                    ..Default::default()
287                },
288                ..Default::default()
289            };
290
291            let future = async move {
292                use futures::{StreamExt, pin_mut};
293                let stream = client.send(&bot_id, &[message], &[]);
294
295                let filtered = stream
296                    .filter_map(|r| async move { r.value().map(|c| c.text.clone()) })
297                    .filter(|text| futures::future::ready(!text.is_empty()));
298                pin_mut!(filtered);
299                let text = filtered.next().await;
300
301                if let Some(text) = text {
302                    ui.defer_with_redraw(move |me: &mut SttInput, cx, scope| {
303                        me.handle_transcription(cx, text, scope);
304                    });
305                } else {
306                    ui.defer_with_redraw(move |me: &mut SttInput, cx, scope| {
307                        me.cancel_recording(cx, scope);
308                    });
309                }
310            };
311
312            self.abort_handle = Some(spawn_abort_on_drop(future));
313        }
314    }
315
316    fn handle_transcription(&mut self, cx: &mut Cx, text: String, _scope: &mut Scope) {
317        self.state = SttInputState::Idle;
318        self.abort_handle = None;
319        let uid = self.widget_uid();
320        cx.widget_action(uid, SttInputAction::Transcribed(text));
321    }
322
323    /// When the transcription is ready, read it from the actions.
324    pub fn transcribed(&self, actions: &Actions) -> Option<String> {
325        actions
326            .find_widget_action(self.widget_uid())
327            .map(|wa| wa.cast::<SttInputAction>())
328            .and_then(|action| match action {
329                SttInputAction::Transcribed(text) => Some(text),
330                _ => None,
331            })
332    }
333
334    /// Check if the transcription was cancelled.
335    pub fn cancelled(&self, actions: &Actions) -> bool {
336        actions
337            .find_widget_action(self.widget_uid())
338            .map(|wa| wa.cast::<SttInputAction>())
339            .map_or(false, |action| matches!(action, SttInputAction::Cancelled))
340    }
341}
342
343impl SttInputRef {
344    /// Immutable access to the underlying [`SttInput`].
345    ///
346    /// Panics if the widget reference is empty or if it's already
347    /// borrowed.
348    pub fn read(&self) -> std::cell::Ref<'_, SttInput> {
349        self.borrow().unwrap()
350    }
351
352    /// Mutable access to the underlying [`SttInput`].
353    ///
354    /// Panics if the widget reference is empty or if it's already
355    /// borrowed.
356    pub fn write(&mut self) -> std::cell::RefMut<'_, SttInput> {
357        self.borrow_mut().unwrap()
358    }
359}
360
/// Formats a duration in seconds as `M:SS`, dropping the fractional part.
///
/// Minutes are not folded into hours (one hour renders as `60:00`), and
/// negative or NaN inputs render as `0:00` because the float-to-`u64` cast
/// saturates.
fn time_to_minutes_seconds(time_secs: f64) -> String {
    let whole_secs = time_secs.floor() as u64;
    let (mins, secs) = (whole_secs / 60, whole_secs % 60);
    format!("{}:{:02}", mins, secs)
}
367
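// TODO: Stop recording when the widget is dropped. Nothing replaces the audio-input callback
// (which holds its own `Arc` to the sample buffer) on drop, so capture presumably continues.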