gio/
write_output_stream.rs

1// Take a look at the license at the top of the repository in the LICENSE file.
2
3use std::{
4    any::Any,
5    io::{Seek, Write},
6};
7
8use crate::{
9    prelude::*, read_input_stream::std_error_to_gio_error, subclass::prelude::*, OutputStream,
10};
11
12mod imp {
13    use std::cell::RefCell;
14
15    use super::*;
16
17    pub(super) enum Writer {
18        Write(AnyWriter),
19        WriteSeek(AnyWriter),
20    }
21
22    #[derive(Default)]
23    pub struct WriteOutputStream {
24        pub(super) write: RefCell<Option<Writer>>,
25    }
26
27    #[glib::object_subclass]
28    impl ObjectSubclass for WriteOutputStream {
29        const NAME: &'static str = "WriteOutputStream";
30        const ALLOW_NAME_CONFLICT: bool = true;
31        type Type = super::WriteOutputStream;
32        type ParentType = OutputStream;
33        type Interfaces = (crate::Seekable,);
34    }
35
36    impl ObjectImpl for WriteOutputStream {}
37
38    impl OutputStreamImpl for WriteOutputStream {
39        fn write(
40            &self,
41            buffer: &[u8],
42            _cancellable: Option<&crate::Cancellable>,
43        ) -> Result<usize, glib::Error> {
44            let mut write = self.write.borrow_mut();
45            let write = match *write {
46                None => {
47                    return Err(glib::Error::new(
48                        crate::IOErrorEnum::Closed,
49                        "Alwritey closed",
50                    ));
51                }
52                Some(Writer::Write(ref mut write)) => write,
53                Some(Writer::WriteSeek(ref mut write)) => write,
54            };
55
56            loop {
57                match std_error_to_gio_error(write.write(buffer)) {
58                    None => continue,
59                    Some(res) => return res,
60                }
61            }
62        }
63
64        fn close(&self, _cancellable: Option<&crate::Cancellable>) -> Result<(), glib::Error> {
65            let _ = self.write.take();
66            Ok(())
67        }
68
69        fn flush(&self, _cancellable: Option<&crate::Cancellable>) -> Result<(), glib::Error> {
70            let mut write = self.write.borrow_mut();
71            let write = match *write {
72                None => {
73                    return Err(glib::Error::new(
74                        crate::IOErrorEnum::Closed,
75                        "Alwritey closed",
76                    ));
77                }
78                Some(Writer::Write(ref mut write)) => write,
79                Some(Writer::WriteSeek(ref mut write)) => write,
80            };
81
82            loop {
83                match std_error_to_gio_error(write.flush()) {
84                    None => continue,
85                    Some(res) => return res,
86                }
87            }
88        }
89    }
90
91    impl SeekableImpl for WriteOutputStream {
92        fn tell(&self) -> i64 {
93            // XXX: stream_position is not stable yet
94            // let mut write = self.write.borrow_mut();
95            // match *write {
96            //     Some(Writer::WriteSeek(ref mut write)) => {
97            //         write.stream_position().map(|pos| pos as i64).unwrap_or(-1)
98            //     },
99            //     _ => -1,
100            // };
101            -1
102        }
103
104        fn can_seek(&self) -> bool {
105            let write = self.write.borrow();
106            matches!(*write, Some(Writer::WriteSeek(_)))
107        }
108
109        fn seek(
110            &self,
111            offset: i64,
112            type_: glib::SeekType,
113            _cancellable: Option<&crate::Cancellable>,
114        ) -> Result<(), glib::Error> {
115            use std::io::SeekFrom;
116
117            let mut write = self.write.borrow_mut();
118            match *write {
119                Some(Writer::WriteSeek(ref mut write)) => {
120                    let pos = match type_ {
121                        glib::SeekType::Cur => SeekFrom::Current(offset),
122                        glib::SeekType::Set => {
123                            if offset < 0 {
124                                return Err(glib::Error::new(
125                                    crate::IOErrorEnum::InvalidArgument,
126                                    "Invalid Argument",
127                                ));
128                            } else {
129                                SeekFrom::Start(offset as u64)
130                            }
131                        }
132                        glib::SeekType::End => SeekFrom::End(offset),
133                        _ => unimplemented!(),
134                    };
135
136                    loop {
137                        match std_error_to_gio_error(write.seek(pos)) {
138                            None => continue,
139                            Some(res) => return res.map(|_| ()),
140                        }
141                    }
142                }
143                _ => Err(glib::Error::new(
144                    crate::IOErrorEnum::NotSupported,
145                    "Truncating not supported",
146                )),
147            }
148        }
149
150        fn can_truncate(&self) -> bool {
151            false
152        }
153
154        fn truncate(
155            &self,
156            _offset: i64,
157            _cancellable: Option<&crate::Cancellable>,
158        ) -> Result<(), glib::Error> {
159            Err(glib::Error::new(
160                crate::IOErrorEnum::NotSupported,
161                "Truncating not supported",
162            ))
163        }
164    }
165}
166
glib::wrapper! {
    /// A GIO output stream that forwards writes to a Rust [`std::io::Write`]
    /// value, and also implements seeking when created with `new_seekable`.
    pub struct WriteOutputStream(ObjectSubclass<imp::WriteOutputStream>) @extends crate::OutputStream, @implements crate::Seekable;
}
170
171impl WriteOutputStream {
172    pub fn new<W: Write + Send + Any + 'static>(write: W) -> WriteOutputStream {
173        let obj: Self = glib::Object::new();
174
175        *obj.imp().write.borrow_mut() = Some(imp::Writer::Write(AnyWriter::new(write)));
176        obj
177    }
178
179    pub fn new_seekable<W: Write + Seek + Send + Any + 'static>(write: W) -> WriteOutputStream {
180        let obj: Self = glib::Object::new();
181
182        *obj.imp().write.borrow_mut() =
183            Some(imp::Writer::WriteSeek(AnyWriter::new_seekable(write)));
184        obj
185    }
186
187    pub fn close_and_take(&self) -> Box<dyn Any + Send + 'static> {
188        let inner = self.imp().write.take();
189
190        let ret = match inner {
191            None => {
192                panic!("Stream already closed or inner taken");
193            }
194            Some(imp::Writer::Write(write)) => write.writer,
195            Some(imp::Writer::WriteSeek(write)) => write.writer,
196        };
197
198        let _ = self.close(crate::Cancellable::NONE);
199
200        match ret {
201            AnyOrPanic::Any(w) => w,
202            AnyOrPanic::Panic(p) => std::panic::resume_unwind(p),
203        }
204    }
205}
206
/// Either the live, type-erased writer or the payload of a panic it raised.
enum AnyOrPanic {
    /// The boxed writer, still usable.
    Any(Box<dyn Any + Send>),
    /// Panic payload caught while calling into the writer.
    Panic(Box<dyn Any + Send>),
}
211
212// Helper struct for dynamically dispatching to any kind of Writer and
213// catching panics along the way
214struct AnyWriter {
215    writer: AnyOrPanic,
216    write_fn: fn(s: &mut AnyWriter, buffer: &[u8]) -> std::io::Result<usize>,
217    flush_fn: fn(s: &mut AnyWriter) -> std::io::Result<()>,
218    seek_fn: Option<fn(s: &mut AnyWriter, pos: std::io::SeekFrom) -> std::io::Result<u64>>,
219}
220
221impl AnyWriter {
222    fn new<W: Write + Any + Send + 'static>(w: W) -> Self {
223        Self {
224            writer: AnyOrPanic::Any(Box::new(w)),
225            write_fn: Self::write_fn::<W>,
226            flush_fn: Self::flush_fn::<W>,
227            seek_fn: None,
228        }
229    }
230
231    fn new_seekable<W: Write + Seek + Any + Send + 'static>(w: W) -> Self {
232        Self {
233            writer: AnyOrPanic::Any(Box::new(w)),
234            write_fn: Self::write_fn::<W>,
235            flush_fn: Self::flush_fn::<W>,
236            seek_fn: Some(Self::seek_fn::<W>),
237        }
238    }
239
240    fn write_fn<W: Write + 'static>(s: &mut AnyWriter, buffer: &[u8]) -> std::io::Result<usize> {
241        s.with_inner(|w: &mut W| w.write(buffer))
242    }
243
244    fn flush_fn<W: Write + 'static>(s: &mut AnyWriter) -> std::io::Result<()> {
245        s.with_inner(|w: &mut W| w.flush())
246    }
247
248    fn seek_fn<W: Seek + 'static>(
249        s: &mut AnyWriter,
250        pos: std::io::SeekFrom,
251    ) -> std::io::Result<u64> {
252        s.with_inner(|w: &mut W| w.seek(pos))
253    }
254
255    fn with_inner<W: 'static, T, F: FnOnce(&mut W) -> std::io::Result<T>>(
256        &mut self,
257        func: F,
258    ) -> std::io::Result<T> {
259        match self.writer {
260            AnyOrPanic::Any(ref mut writer) => {
261                let w = writer.downcast_mut::<W>().unwrap();
262                match std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| func(w))) {
263                    Ok(res) => res,
264                    Err(panic) => {
265                        self.writer = AnyOrPanic::Panic(panic);
266                        Err(std::io::Error::new(std::io::ErrorKind::Other, "Panicked"))
267                    }
268                }
269            }
270            AnyOrPanic::Panic(_) => Err(std::io::Error::new(
271                std::io::ErrorKind::Other,
272                "Panicked before",
273            )),
274        }
275    }
276
277    fn write(&mut self, buffer: &[u8]) -> std::io::Result<usize> {
278        (self.write_fn)(self, buffer)
279    }
280
281    fn flush(&mut self) -> std::io::Result<()> {
282        (self.flush_fn)(self)
283    }
284
285    fn seek(&mut self, pos: std::io::SeekFrom) -> std::io::Result<u64> {
286        if let Some(ref seek_fn) = self.seek_fn {
287            seek_fn(self, pos)
288        } else {
289            unreachable!()
290        }
291    }
292}
293
#[cfg(test)]
mod tests {
    use std::io::Cursor;

    use super::*;

    /// Writer whose every `write` panics, used to check panic propagation.
    struct PanickingWriter;

    impl Write for PanickingWriter {
        fn write(&mut self, _buf: &[u8]) -> std::io::Result<usize> {
            panic!("writer panicked");
        }

        fn flush(&mut self) -> std::io::Result<()> {
            Ok(())
        }
    }

    #[test]
    fn test_write() {
        let cursor = Cursor::new(vec![]);
        let stream = WriteOutputStream::new(cursor);

        assert_eq!(
            stream.write(&[1, 2, 3, 4, 5, 6, 7, 8, 9, 10], crate::Cancellable::NONE),
            Ok(10)
        );

        let inner = stream.close_and_take();
        assert!(inner.is::<Cursor<Vec<u8>>>());
        let inner = inner.downcast_ref::<Cursor<Vec<u8>>>().unwrap();
        assert_eq!(inner.get_ref(), &[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]);
    }

    #[test]
    fn test_write_seek() {
        let cursor = Cursor::new(vec![]);
        let stream = WriteOutputStream::new_seekable(cursor);

        assert_eq!(
            stream.write(&[1, 2, 3, 4, 5, 6, 7, 8, 9, 10], crate::Cancellable::NONE),
            Ok(10)
        );

        assert!(stream.can_seek());
        assert_eq!(
            stream.seek(0, glib::SeekType::Set, crate::Cancellable::NONE),
            Ok(())
        );

        assert_eq!(
            stream.write(
                &[11, 12, 13, 14, 15, 16, 17, 18, 19, 20],
                crate::Cancellable::NONE
            ),
            Ok(10)
        );

        let inner = stream.close_and_take();
        assert!(inner.is::<Cursor<Vec<u8>>>());
        let inner = inner.downcast_ref::<Cursor<Vec<u8>>>().unwrap();
        assert_eq!(inner.get_ref(), &[11, 12, 13, 14, 15, 16, 17, 18, 19, 20]);
    }

    #[test]
    fn test_write_not_seekable() {
        let stream = WriteOutputStream::new(Cursor::new(Vec::<u8>::new()));

        assert!(!stream.can_seek());
        assert!(!stream.can_truncate());
        assert!(stream
            .seek(0, glib::SeekType::Set, crate::Cancellable::NONE)
            .is_err());
        assert!(stream.truncate(0, crate::Cancellable::NONE).is_err());
    }

    #[test]
    #[should_panic(expected = "Stream already closed or inner taken")]
    fn test_close_and_take_after_close_panics() {
        let stream = WriteOutputStream::new(Cursor::new(Vec::<u8>::new()));
        stream.close(crate::Cancellable::NONE).unwrap();
        stream.close_and_take();
    }

    #[test]
    #[should_panic(expected = "writer panicked")]
    fn test_writer_panic_is_resumed_by_close_and_take() {
        let stream = WriteOutputStream::new(PanickingWriter);
        // The panic is caught and surfaced as an error instead of unwinding.
        assert!(stream.write(&[1], crate::Cancellable::NONE).is_err());
        // ...and re-raised once the caller takes the writer back.
        stream.close_and_take();
    }
}