mysql/conn/
stmt.rs

// Copyright (c) 2020 rust-mysql-simple contributors
//
// Licensed under the Apache License, Version 2.0
// <LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0> or the MIT
// license <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
// option. All files in the project carrying such notice may not be copied,
// modified, or distributed except according to those terms.
8
9use crossbeam_utils::atomic::AtomicCell;
10use mysql_common::{io::ParseBuf, packets::StmtPacket, proto::MyDeserialize};
11
12use std::{borrow::Cow, fmt, io, ptr::NonNull, sync::Arc};
13
14use crate::{prelude::*, Column, Result};
15
16#[derive(Debug, Eq, PartialEq)]
17pub(crate) struct InnerStmt {
18    columns: Option<Arc<[Column]>>,
19    /// This cached value overrides the column metadata stored in the `inner` field.
20    ///
21    /// See MARIADB_CLIENT_CACHE_METADATA capability.
22    columns_cache: ColumnCache,
23    params: Option<Arc<[Column]>>,
24    stmt_packet: StmtPacket,
25    connection_id: u32,
26}
27
28impl<'de> MyDeserialize<'de> for InnerStmt {
29    const SIZE: Option<usize> = StmtPacket::SIZE;
30    type Ctx = u32;
31
32    fn deserialize(connection_id: Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
33        let stmt_packet = buf.parse(())?;
34
35        Ok(InnerStmt {
36            columns: None,
37            columns_cache: ColumnCache::new(),
38            params: None,
39            stmt_packet,
40            connection_id,
41        })
42    }
43}
44
45impl InnerStmt {
46    pub fn with_params(mut self, params: Option<Vec<Column>>) -> Self {
47        self.params = params.map(Into::into);
48        self
49    }
50
51    pub fn with_columns(mut self, columns: Option<Vec<Column>>) -> Self {
52        self.columns = columns.map(|x| x.into());
53        self
54    }
55
56    pub fn columns(&self) -> Arc<[Column]> {
57        self.columns_cache
58            .get_columns()
59            .or_else(|| self.columns.clone())
60            .unwrap_or_default()
61    }
62
63    pub fn update_columns_metadata(&self, columns: Vec<Column>) {
64        self.columns_cache.set_columns(columns);
65    }
66
67    pub fn params(&self) -> &[Column] {
68        self.params.as_ref().map(AsRef::as_ref).unwrap_or(&[])
69    }
70
71    pub fn id(&self) -> u32 {
72        self.stmt_packet.statement_id()
73    }
74
75    pub const fn connection_id(&self) -> u32 {
76        self.connection_id
77    }
78
79    pub fn num_params(&self) -> u16 {
80        self.stmt_packet.num_params()
81    }
82
83    pub fn num_columns(&self) -> u16 {
84        self.stmt_packet.num_columns()
85    }
86}
87
88#[derive(Debug, Clone, Eq, PartialEq)]
89pub struct Statement {
90    pub(crate) inner: Arc<InnerStmt>,
91    pub(crate) named_params: Option<Vec<Vec<u8>>>,
92}
93
94impl Statement {
95    pub(crate) fn new(inner: Arc<InnerStmt>, named_params: Option<Vec<Vec<u8>>>) -> Self {
96        Self {
97            inner,
98            named_params,
99        }
100    }
101
102    pub fn columns(&self) -> Arc<[Column]> {
103        self.inner.columns()
104    }
105
106    /// Overrides columns metadata for this statement.
107    ///
108    /// See MARIADB_CLIENT_CACHE_METADATA capability.
109    pub(crate) fn update_columns_metadata(&self, columns: Vec<Column>) {
110        self.inner.update_columns_metadata(columns);
111    }
112
113    pub fn params(&self) -> &[Column] {
114        self.inner.params()
115    }
116
117    pub fn id(&self) -> u32 {
118        self.inner.id()
119    }
120
121    pub fn connection_id(&self) -> u32 {
122        self.inner.connection_id()
123    }
124
125    pub fn num_params(&self) -> u16 {
126        self.inner.num_params()
127    }
128
129    pub fn num_columns(&self) -> u16 {
130        self.inner.num_columns()
131    }
132}
133
134impl AsStatement for Statement {
135    fn as_statement<Q: Queryable>(&self, _queryable: &mut Q) -> Result<Cow<'_, Statement>> {
136        Ok(Cow::Borrowed(self))
137    }
138}
139
140impl<'a> AsStatement for &'a Statement {
141    fn as_statement<Q: Queryable>(&self, _queryable: &mut Q) -> Result<Cow<'_, Statement>> {
142        Ok(Cow::Borrowed(self))
143    }
144}
145
146impl<T: AsRef<str>> AsStatement for T {
147    fn as_statement<Q: Queryable>(&self, queryable: &mut Q) -> Result<Cow<'static, Statement>> {
148        let statement = queryable.prep(self.as_ref())?;
149        Ok(Cow::Owned(statement))
150    }
151}
152
153/// This is to make raw Arc pointer Send and Sync
154///
155/// This splits fat `*const [Column]` pointer to its components
156#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
157#[repr(transparent)]
158struct ColumnsArcPtr((NonNull<Column>, usize));
159
160impl ColumnsArcPtr {
161    fn from_arc(arc: Arc<[Column]>) -> Self {
162        let len = arc.len();
163        let ptr = Arc::into_raw(arc);
164        // SAFETY: the `Arc` structure itself contains NonNull so this is either safe
165        // or someone created a broken `Arc` using unsafe code.
166        let ptr = unsafe { NonNull::new_unchecked(ptr as *const Column as *mut Column) };
167        Self((ptr, len))
168    }
169
170    fn to_arc(self) -> Arc<[Column]> {
171        let columns = self.into_arc();
172        let clone = columns.clone();
173        // ignore the pointer because it is already stored in self
174        let _ = Arc::into_raw(columns);
175        clone
176    }
177
178    fn into_arc(self) -> Arc<[Column]> {
179        let fat_pointer = NonNull::slice_from_raw_parts(self.0 .0, self.0 .1);
180        // SAFETY: non-null pointer always points to a valid Arc
181        unsafe { Arc::from_raw(fat_pointer.as_ptr()) }
182    }
183}
184
185unsafe impl Send for ColumnsArcPtr {}
186unsafe impl Sync for ColumnsArcPtr {}
187
188struct ColumnCache {
189    columns: AtomicCell<Option<ColumnsArcPtr>>,
190}
191
192impl ColumnCache {
193    fn new() -> Self {
194        Self {
195            columns: AtomicCell::new(None),
196        }
197    }
198
199    fn get_columns(&self) -> Option<Arc<[Column]>> {
200        self.columns.load().map(|x| x.to_arc())
201    }
202
203    fn set_columns(&self, new_columns: Vec<Column>) {
204        let new_columns: Arc<[Column]> = new_columns.into();
205        let new_ptr = ColumnsArcPtr::from_arc(new_columns);
206
207        let Some(old_ptr) = self.columns.swap(Some(new_ptr)) else {
208            return;
209        };
210
211        // drop the old `Arc`
212        old_ptr.into_arc();
213    }
214}
215
216impl fmt::Debug for ColumnCache {
217    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
218        f.debug_struct("ColumnCache")
219            .field("columns", &self.get_columns())
220            .finish()
221    }
222}
223
224impl PartialEq for ColumnCache {
225    fn eq(&self, other: &Self) -> bool {
226        self.get_columns() == other.get_columns()
227    }
228}
229
230impl Eq for ColumnCache {}
231
232impl Drop for ColumnCache {
233    fn drop(&mut self) {
234        // drop `Arc` if any
235        self.columns.load().map(|x| x.into_arc());
236    }
237}