mysql/conn/
mod.rs

1// Copyright (c) 2020 rust-mysql-simple contributors
2//
3// Licensed under the Apache License, Version 2.0
4// <LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0> or the MIT
5// license <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
6// option. All files in the project carrying such notice may not be copied,
7// modified, or distributed except according to those terms.
8
9use bytes::{Buf, BufMut};
10#[cfg(feature = "binlog")]
11use mysql_common::packets::binlog_request::BinlogRequest;
12use mysql_common::{
13    constants::UTF8MB4_GENERAL_CI,
14    crypto,
15    io::{ParseBuf, ReadMysqlExt},
16    named_params::ParsedNamedParams,
17    packets::{
18        AuthPlugin, AuthSwitchRequest, Column, ComChangeUser, ComChangeUserMoreData, ComStmtClose,
19        ComStmtExecuteRequestBuilder, ComStmtSendLongData, CommonOkPacket, ErrPacket,
20        HandshakePacket, HandshakeResponse, OkPacket, OkPacketDeserializer, OkPacketKind,
21        OldAuthSwitchRequest, OldEofPacket, ResultSetTerminator, SessionStateInfo,
22    },
23    proto::{codec::Compression, sync_framed::MySyncFramed, MySerialize},
24};
25
26use mysql_common::{
27    constants::{DEFAULT_MAX_ALLOWED_PACKET, UTF8_GENERAL_CI},
28    packets::SslRequest,
29};
30
31use std::{
32    borrow::{Borrow, Cow},
33    collections::HashMap,
34    convert::TryFrom,
35    io::{self, Write as _},
36    mem,
37    ops::{Deref, DerefMut},
38    process,
39    sync::Arc,
40};
41
42#[cfg(unix)]
43use std::os::unix::io::{AsRawFd, RawFd};
44
45use crate::{
46    buffer_pool::{get_buffer, Buffer},
47    conn::{
48        local_infile::LocalInfile,
49        pool::{Pool, PooledConn},
50        query_result::{Binary, ResultSetMeta, Text},
51        stmt::{InnerStmt, Statement},
52        stmt_cache::StmtCache,
53        transaction::{AccessMode, TxOpts},
54    },
55    consts::{CapabilityFlags, Command, MariadbCapabilities, StatusFlags, MAX_PAYLOAD_LEN},
56    from_value, from_value_opt,
57    io::Stream,
58    prelude::*,
59    ChangeUserOpts,
60    DriverError::{
61        CleartextPluginDisabled, MismatchedStmtParams, NamedParamsForPositionalQuery,
62        OldMysqlPasswordDisabled, Protocol41NotSet, ReadOnlyTransNotSupported, SetupError,
63        UnexpectedPacket, UnknownAuthPlugin, UnsupportedProtocol,
64    },
65    Error::{self, DriverError, MySqlError},
66    LocalInfileHandler, Opts, OptsBuilder, Params, QueryResult, Result, Transaction,
67    Value::{self, Bytes, NULL},
68};
69
70use crate::DriverError::TlsNotSupported;
71use crate::SslOpts;
72
73#[cfg(feature = "binlog")]
74use self::binlog_stream::BinlogStream;
75
76#[cfg(feature = "binlog")]
77pub mod binlog_stream;
78pub mod local_infile;
79pub mod opts;
80pub mod pool;
81pub mod query;
82pub mod query_result;
83pub mod queryable;
84pub mod stmt;
85mod stmt_cache;
86pub mod transaction;
87
/// Mutable connection.
///
/// Abstracts over the different ways a caller can hand out a connection that
/// can be used mutably; it dereferences to [`Conn`] in every case.
#[derive(Debug)]
pub enum ConnMut<'c, 't, 'tc> {
    /// Mutable borrow of a plain connection.
    Mut(&'c mut Conn),
    /// Mutable borrow of a transaction; dereferences to the transaction's connection.
    TxMut(&'t mut Transaction<'tc>),
    /// Owned connection.
    Owned(Conn),
    /// Owned pooled connection.
    Pooled(PooledConn),
}
96
97impl From<Conn> for ConnMut<'static, 'static, 'static> {
98    fn from(conn: Conn) -> Self {
99        ConnMut::Owned(conn)
100    }
101}
102
103impl From<PooledConn> for ConnMut<'static, 'static, 'static> {
104    fn from(conn: PooledConn) -> Self {
105        ConnMut::Pooled(conn)
106    }
107}
108
109impl<'a> From<&'a mut Conn> for ConnMut<'a, 'static, 'static> {
110    fn from(conn: &'a mut Conn) -> Self {
111        ConnMut::Mut(conn)
112    }
113}
114
115impl<'a> From<&'a mut PooledConn> for ConnMut<'a, 'static, 'static> {
116    fn from(conn: &'a mut PooledConn) -> Self {
117        ConnMut::Mut(conn.as_mut())
118    }
119}
120
121impl<'t, 'tc> From<&'t mut Transaction<'tc>> for ConnMut<'static, 't, 'tc> {
122    fn from(tx: &'t mut Transaction<'tc>) -> Self {
123        ConnMut::TxMut(tx)
124    }
125}
126
127impl TryFrom<&Pool> for ConnMut<'static, 'static, 'static> {
128    type Error = Error;
129
130    fn try_from(pool: &Pool) -> Result<Self> {
131        pool.get_conn().map(From::from)
132    }
133}
134
135impl Deref for ConnMut<'_, '_, '_> {
136    type Target = Conn;
137
138    fn deref(&self) -> &Conn {
139        match self {
140            ConnMut::Mut(conn) => conn,
141            ConnMut::TxMut(tx) => &tx.conn,
142            ConnMut::Owned(conn) => conn,
143            ConnMut::Pooled(conn) => conn.as_ref(),
144        }
145    }
146}
147
148impl DerefMut for ConnMut<'_, '_, '_> {
149    fn deref_mut(&mut self) -> &mut Conn {
150        match self {
151            ConnMut::Mut(conn) => conn,
152            ConnMut::TxMut(tx) => &mut tx.conn,
153            ConnMut::Owned(ref mut conn) => conn,
154            ConnMut::Pooled(ref mut conn) => conn.as_mut(),
155        }
156    }
157}
158
/// Kind of response to a query or statement execution, before it is turned
/// into a [`ResultSetMeta`].
pub(crate) enum ResultSetInfo {
    /// No result set; carries the OK packet sent instead.
    Empty(OkPacket<'static>),
    /// Result set whose column metadata was sent by the server.
    NonEmptyWithMeta(Vec<Column>),
    // For MariaDB via MARIADB_CLIENT_CACHE_METADATA
    /// Result set whose column metadata was skipped by the server; for statements
    /// the metadata cached on the statement is reused instead.
    NonEmptySkipMeta,
}
165
166impl ResultSetInfo {
167    pub(crate) fn into_query_meta(self) -> ResultSetMeta {
168        match self {
169            ResultSetInfo::Empty(ok_packet) => ResultSetMeta::Empty(ok_packet),
170            ResultSetInfo::NonEmptyWithMeta(columns) => {
171                ResultSetMeta::NonEmptyWithMeta(columns.into())
172            }
173            ResultSetInfo::NonEmptySkipMeta => {
174                // TODO: Server misbehavior — emit runtime error
175                ResultSetMeta::NonEmptyWithMeta(Default::default())
176            }
177        }
178    }
179
180    pub(crate) fn into_statement_meta(self, conn: &Conn, stmt: &Statement) -> ResultSetMeta {
181        match self {
182            ResultSetInfo::Empty(ok_packet) => ResultSetMeta::Empty(ok_packet),
183            ResultSetInfo::NonEmptyWithMeta(columns) => {
184                stmt.update_columns_metadata(columns);
185                ResultSetMeta::NonEmptyWithMeta(stmt.columns())
186            }
187            ResultSetInfo::NonEmptySkipMeta => {
188                assert!(
189                    conn.has_mariadb_capability(MariadbCapabilities::MARIADB_CLIENT_CACHE_METADATA),
190                    "metadata skipped but no MARIADB_CLIENT_CACHE_METADATA capability negotiated"
191                );
192                ResultSetMeta::NonEmptyWithMeta(stmt.columns())
193            }
194        }
195    }
196}
197
/// Connection internals.
///
/// Boxed inside [`Conn`]; holds the transport, the connection options and the
/// state negotiated with (or reported by) the server.
#[derive(Debug)]
struct ConnInner {
    /// Connection options (updated in place by `COM_CHANGE_USER`).
    opts: Opts,
    /// Framed transport; `None` until `connect_stream` succeeds.
    stream: Option<MySyncFramed<Stream>>,
    /// Prepared statement cache; cleared on connection reset / user change.
    stmt_cache: StmtCache,

    // TODO: clean this up
    /// Server version parsed from the handshake packet, if parsable.
    server_version: Option<(u16, u16, u16)>,
    /// MariaDB server version parsed from the handshake packet, if present.
    mariadb_server_version: Option<(u16, u16, u16)>,

    /// Last Ok packet, if any.
    ok_packet: Option<OkPacket<'static>>,
    /// Intersection of the server's and the client's capability flags.
    capability_flags: CapabilityFlags,
    /// Negotiated MariaDB extended capabilities; left empty unless the handshake
    /// carried a MariaDB version and MariaDB's extended-capability signal.
    mariadb_ext_capabilities: MariadbCapabilities,
    /// Connection id reported in the handshake.
    connection_id: u32,
    /// Status flags from the handshake / last OK packet; cleared on error.
    status_flags: StatusFlags,
    /// Default collation reported by the server in the handshake.
    character_set: u8,
    // NOTE(review): reset to 0 after COM_RESET_CONNECTION / COM_CHANGE_USER; other
    // readers/writers are outside this view.
    last_command: u8,
    connected: bool,
    // NOTE(review): cleared by `handle_err`; presumably tracks pending result sets — confirm.
    has_results: bool,
    local_infile_handler: Option<LocalInfileHandler>,

    /// Auth plugin in use (from the handshake or a later auth switch request).
    auth_plugin: AuthPlugin<'static>,
    /// Scramble used to generate auth data (from the handshake or an auth switch request).
    nonce: Vec<u8>,

    /// This flag is to opt-in/opt-out from reset upon return to a pool.
    pub(crate) reset_upon_return: bool,
}
227
228impl ConnInner {
229    fn empty(opts: Opts) -> Self {
230        ConnInner {
231            stmt_cache: StmtCache::new(opts.get_stmt_cache_size()),
232            stream: None,
233            capability_flags: CapabilityFlags::empty(),
234            status_flags: StatusFlags::empty(),
235            connection_id: 0u32,
236            character_set: 0u8,
237            ok_packet: None,
238            last_command: 0u8,
239            connected: false,
240            has_results: false,
241            server_version: None,
242            mariadb_server_version: None,
243            mariadb_ext_capabilities: MariadbCapabilities::empty(),
244            local_infile_handler: None,
245            auth_plugin: AuthPlugin::MysqlNativePassword,
246            nonce: Vec::new(),
247            reset_upon_return: opts.get_pool_opts().reset_connection(),
248
249            opts,
250        }
251    }
252}
253
/// Mysql connection.
///
/// All connection state lives in a boxed `ConnInner`, so `Conn` itself is a
/// single pointer and cheap to move.
#[derive(Debug)]
pub struct Conn(Box<ConnInner>);
257
258impl Conn {
259    /// Must not be called before handle_handshake.
260    const fn has_capability(&self, flag: CapabilityFlags) -> bool {
261        self.0.capability_flags.contains(flag)
262    }
263
264    /// Must not be called before handle_handshake.
265    const fn has_mariadb_capability(&self, flag: MariadbCapabilities) -> bool {
266        self.0.mariadb_ext_capabilities.contains(flag)
267    }
268
269    /// Returns version number reported by the server.
270    pub fn server_version(&self) -> (u16, u16, u16) {
271        self.0
272            .server_version
273            .or(self.0.mariadb_server_version)
274            .unwrap()
275    }
276
277    /// Returns connection identifier.
278    pub fn connection_id(&self) -> u32 {
279        self.0.connection_id
280    }
281
282    /// Returns number of rows affected by the last query.
283    pub fn affected_rows(&self) -> u64 {
284        self.0
285            .ok_packet
286            .as_ref()
287            .map(OkPacket::affected_rows)
288            .unwrap_or_default()
289    }
290
291    /// Returns last insert id of the last query.
292    ///
293    /// Returns zero if there was no last insert id.
294    pub fn last_insert_id(&self) -> u64 {
295        self.0
296            .ok_packet
297            .as_ref()
298            .and_then(OkPacket::last_insert_id)
299            .unwrap_or_default()
300    }
301
302    /// Returns number of warnings, reported by the server.
303    pub fn warnings(&self) -> u16 {
304        self.0
305            .ok_packet
306            .as_ref()
307            .map(OkPacket::warnings)
308            .unwrap_or_default()
309    }
310
311    /// [Info], reported by the server.
312    ///
313    /// Will be empty if not defined.
314    ///
315    /// [Info]: http://dev.mysql.com/doc/internals/en/packet-OK_Packet.html
316    pub fn info_ref(&self) -> &[u8] {
317        self.0
318            .ok_packet
319            .as_ref()
320            .and_then(OkPacket::info_ref)
321            .unwrap_or_default()
322    }
323
324    /// [Info], reported by the server.
325    ///
326    /// Will be empty if not defined.
327    ///
328    /// [Info]: http://dev.mysql.com/doc/internals/en/packet-OK_Packet.html
329    pub fn info_str(&self) -> Cow<str> {
330        self.0
331            .ok_packet
332            .as_ref()
333            .and_then(OkPacket::info_str)
334            .unwrap_or_default()
335    }
336
337    pub fn session_state_changes(&self) -> io::Result<Vec<SessionStateInfo<'_>>> {
338        self.0
339            .ok_packet
340            .as_ref()
341            .map(|ok| ok.session_state_info())
342            .transpose()
343            .map(Option::unwrap_or_default)
344    }
345
346    fn stream_ref(&self) -> &MySyncFramed<Stream> {
347        self.0.stream.as_ref().expect("incomplete connection")
348    }
349
350    fn stream_mut(&mut self) -> &mut MySyncFramed<Stream> {
351        self.0.stream.as_mut().expect("incomplete connection")
352    }
353
354    fn is_insecure(&self) -> bool {
355        self.stream_ref().get_ref().is_insecure()
356    }
357
358    fn is_socket(&self) -> bool {
359        self.stream_ref().get_ref().is_socket()
360    }
361
362    /// Check the connection can be improved.
363    #[allow(unused_assignments)]
364    fn can_improved(&mut self) -> Result<Option<Opts>> {
365        if self.0.opts.get_prefer_socket() && self.0.opts.addr_is_loopback() {
366            let mut socket = None;
367            #[cfg(test)]
368            {
369                socket = self.0.opts.0.injected_socket.clone();
370            }
371            if socket.is_none() {
372                socket = self.get_system_var("socket")?.map(from_value::<String>);
373            }
374            if let Some(socket) = socket {
375                if self.0.opts.get_socket().is_none() {
376                    let socket_opts = OptsBuilder::from_opts(self.0.opts.clone());
377                    if !socket.is_empty() {
378                        return Ok(Some(socket_opts.socket(Some(socket)).into()));
379                    }
380                }
381            }
382        }
383        Ok(None)
384    }
385
386    /// Creates new `Conn`.
387    pub fn new<T, E>(opts: T) -> Result<Conn>
388    where
389        Opts: TryFrom<T, Error = E>,
390        crate::Error: From<E>,
391    {
392        let opts = Opts::try_from(opts)?;
393        let mut conn = Conn(Box::new(ConnInner::empty(opts)));
394        conn.connect_stream()?;
395        conn.connect()?;
396        let mut conn = {
397            if let Some(new_opts) = conn.can_improved()? {
398                let mut improved_conn = Conn(Box::new(ConnInner::empty(new_opts)));
399                improved_conn
400                    .connect_stream()
401                    .and_then(|_| {
402                        improved_conn.connect()?;
403                        Ok(improved_conn)
404                    })
405                    .unwrap_or(conn)
406            } else {
407                conn
408            }
409        };
410        for cmd in conn.0.opts.get_init() {
411            conn.query_drop(cmd)?;
412        }
413        Ok(conn)
414    }
415
416    fn exec_com_reset_connection(&mut self) -> Result<()> {
417        self.write_command(Command::COM_RESET_CONNECTION, &[])?;
418        let packet = self.read_packet()?;
419        self.handle_ok::<CommonOkPacket>(&packet)?;
420        self.0.last_command = 0;
421        self.0.stmt_cache.clear();
422        Ok(())
423    }
424
425    fn exec_com_change_user(&mut self, opts: ChangeUserOpts) -> Result<()> {
426        opts.update_opts(&mut self.0.opts);
427        let com_change_user = ComChangeUser::new()
428            .with_user(self.0.opts.get_user().map(|x| x.as_bytes()))
429            .with_database(self.0.opts.get_db_name().map(|x| x.as_bytes()))
430            .with_auth_plugin_data(
431                self.0
432                    .auth_plugin
433                    .gen_data(self.0.opts.get_pass(), &self.0.nonce)
434                    .as_deref(),
435            )
436            .with_more_data(Some(
437                ComChangeUserMoreData::new(if self.server_version() >= (5, 5, 3) {
438                    UTF8MB4_GENERAL_CI
439                } else {
440                    UTF8_GENERAL_CI
441                })
442                .with_auth_plugin(Some(self.0.auth_plugin.clone()))
443                .with_connect_attributes(self.0.opts.get_connect_attrs().cloned()),
444            ))
445            .into_owned();
446        self.write_command_raw(&com_change_user)?;
447        self.0.last_command = 0;
448        self.0.stmt_cache.clear();
449        self.continue_auth(false)
450    }
451
452    /// Tries to reset the connection.
453    ///
454    /// This function will try to invoke COM_RESET_CONNECTION with
455    /// a fall back to COM_CHANGE_USER on older servers.
456    ///
457    /// ## Warning
458    ///
459    /// There is a long-standing bug in mysql 5.6 that kills this functionality in presence
460    /// of connection attributes (see [Bug #92954](https://bugs.mysql.com/bug.php?id=92954)).
461    ///
462    /// ## Note
463    ///
464    /// Re-executes [`Opts::get_init`].
465    pub fn reset(&mut self) -> Result<()> {
466        let reset_result = match (self.0.server_version, self.0.mariadb_server_version) {
467            (Some(ref version), _) if *version > (5, 7, 3) => self.exec_com_reset_connection(),
468            (_, Some(ref version)) if *version >= (10, 2, 7) => self.exec_com_reset_connection(),
469            _ => return self.exec_com_change_user(ChangeUserOpts::DEFAULT),
470        };
471
472        match reset_result {
473            Ok(_) => (),
474            Err(crate::Error::MySqlError(_)) => {
475                // fallback to COM_CHANGE_USER if server reports an error for COM_RESET_CONNECTION
476                self.exec_com_change_user(ChangeUserOpts::DEFAULT)?;
477            }
478            Err(e) => return Err(e),
479        }
480
481        for cmd in self.0.opts.get_init() {
482            self.query_drop(cmd)?;
483        }
484
485        Ok(())
486    }
487
488    /// Executes [`COM_CHANGE_USER`][1].
489    ///
490    /// This might be used as an older and slower alternative to `COM_RESET_CONNECTION` that
491    /// works on MySql prior to 5.7.3 (MariaDb prior ot 10.2.4).
492    ///
493    /// ## Note
494    ///
495    /// * Using non-default `opts` for a pooled connection is discouraging.
496    /// * Connection options will be updated permanently.
497    ///
498    /// ## Warning
499    ///
500    /// There is a long-standing bug in mysql 5.6 that kills this functionality in presence
501    /// of connection attributes (see [Bug #92954](https://bugs.mysql.com/bug.php?id=92954)).
502    ///
503    /// [1]: https://dev.mysql.com/doc/c-api/5.7/en/mysql-change-user.html
504    pub fn change_user(&mut self, opts: ChangeUserOpts) -> Result<()> {
505        self.exec_com_change_user(opts)
506    }
507
508    fn switch_to_ssl(&mut self, ssl_opts: SslOpts) -> Result<()> {
509        let stream = self.0.stream.take().expect("incomplete conn");
510        let (in_buf, out_buf, codec, stream) = stream.destruct();
511        let stream = stream.make_secure(self.0.opts.get_host(), ssl_opts)?;
512        let stream = MySyncFramed::construct(in_buf, out_buf, codec, stream);
513        self.0.stream = Some(stream);
514        Ok(())
515    }
516
517    fn connect_stream(&mut self) -> Result<()> {
518        let opts = &self.0.opts;
519        let read_timeout = opts.get_read_timeout().cloned();
520        let write_timeout = opts.get_write_timeout().cloned();
521        let tcp_keepalive_time = opts.get_tcp_keepalive_time_ms();
522        #[cfg(any(target_os = "linux", target_os = "macos",))]
523        let tcp_keepalive_probe_interval_secs = opts.get_tcp_keepalive_probe_interval_secs();
524        #[cfg(any(target_os = "linux", target_os = "macos",))]
525        let tcp_keepalive_probe_count = opts.get_tcp_keepalive_probe_count();
526        #[cfg(target_os = "linux")]
527        let tcp_user_timeout = opts.get_tcp_user_timeout_ms();
528        let tcp_nodelay = opts.get_tcp_nodelay();
529        let tcp_connect_timeout = opts.get_tcp_connect_timeout();
530        let bind_address = opts.bind_address().cloned();
531        let stream = if let Some(socket) = opts.get_socket() {
532            Stream::connect_socket(socket, read_timeout, write_timeout)?
533        } else {
534            let port = opts.get_tcp_port();
535            let ip_or_hostname = match opts.get_host() {
536                url::Host::Domain(domain) => domain,
537                url::Host::Ipv4(ip) => ip.to_string(),
538                url::Host::Ipv6(ip) => ip.to_string(),
539            };
540            Stream::connect_tcp(
541                &ip_or_hostname,
542                port,
543                read_timeout,
544                write_timeout,
545                tcp_keepalive_time,
546                #[cfg(any(target_os = "linux", target_os = "macos",))]
547                tcp_keepalive_probe_interval_secs,
548                #[cfg(any(target_os = "linux", target_os = "macos",))]
549                tcp_keepalive_probe_count,
550                #[cfg(target_os = "linux")]
551                tcp_user_timeout,
552                tcp_nodelay,
553                tcp_connect_timeout,
554                bind_address,
555            )?
556        };
557        self.0.stream = Some(MySyncFramed::new(stream));
558        Ok(())
559    }
560
561    fn raw_read_packet(&mut self, buffer: &mut Vec<u8>) -> Result<()> {
562        if !self.stream_mut().next_packet(buffer)? {
563            Err(Error::server_disconnected())
564        } else {
565            Ok(())
566        }
567    }
568
569    fn read_packet(&mut self) -> Result<Buffer> {
570        loop {
571            let mut buffer = get_buffer();
572            match self.raw_read_packet(buffer.as_mut()) {
573                Ok(()) if buffer.first() == Some(&0xff) => {
574                    match ParseBuf(&buffer).parse(self.0.capability_flags)? {
575                        ErrPacket::Error(server_error) => {
576                            self.handle_err();
577                            return Err(MySqlError(From::from(server_error)));
578                        }
579                        ErrPacket::Progress(_progress_report) => {
580                            // TODO: Report progress
581                            continue;
582                        }
583                    }
584                }
585                Ok(()) => return Ok(buffer),
586                Err(e) => {
587                    self.handle_err();
588                    return Err(e);
589                }
590            }
591        }
592    }
593
594    fn drop_packet(&mut self) -> Result<()> {
595        self.read_packet().map(drop)
596    }
597
598    fn write_struct<T: MySerialize>(&mut self, s: &T) -> Result<()> {
599        let mut buf = get_buffer();
600        s.serialize(buf.as_mut());
601        self.write_packet(&mut &*buf)
602    }
603
604    fn write_packet<T: Buf>(&mut self, data: &mut T) -> Result<()> {
605        self.stream_mut().send(data)?;
606        Ok(())
607    }
608
609    fn handle_handshake(&mut self, hp: &HandshakePacket<'_>) {
610        self.0.capability_flags = hp.capabilities() & self.get_client_flags();
611        self.0.status_flags = hp.status_flags();
612        self.0.connection_id = hp.connection_id();
613        self.0.character_set = hp.default_collation();
614        self.0.server_version = hp.server_version_parsed();
615        self.0.mariadb_server_version = hp.maria_db_server_version_parsed();
616        // If we have a MariaDB server version, we are using mariadb extended capbilities from the handshake packet.
617        // MariaDB does not set 1 standard capability flag bit to indicate that it supports extended capabilities.
618        if self.0.mariadb_server_version.is_some()
619            && !self
620                .0
621                .capability_flags
622                .contains(CapabilityFlags::CLIENT_LONG_PASSWORD)
623        {
624            self.0.mariadb_ext_capabilities =
625                hp.mariadb_ext_capabilities() & self.get_mariadb_client_flags();
626        }
627    }
628
629    fn handle_ok<'a, T: OkPacketKind>(
630        &mut self,
631        buffer: &'a Buffer,
632    ) -> crate::Result<OkPacket<'a>> {
633        let ok = ParseBuf(buffer)
634            .parse::<OkPacketDeserializer<T>>(self.0.capability_flags)?
635            .into_inner();
636        self.0.status_flags = ok.status_flags();
637        self.0.ok_packet = Some(ok.clone().into_owned());
638        Ok(ok)
639    }
640
641    fn handle_err(&mut self) {
642        self.0.status_flags = StatusFlags::empty();
643        self.0.has_results = false;
644        self.0.ok_packet = None;
645    }
646
647    fn more_results_exists(&self) -> bool {
648        self.0
649            .status_flags
650            .contains(StatusFlags::SERVER_MORE_RESULTS_EXISTS)
651    }
652
653    fn perform_auth_switch(&mut self, auth_switch_request: AuthSwitchRequest<'_>) -> Result<()> {
654        if matches!(
655            auth_switch_request.auth_plugin(),
656            AuthPlugin::MysqlOldPassword
657        ) && self.0.opts.get_secure_auth()
658        {
659            return Err(DriverError(OldMysqlPasswordDisabled));
660        }
661
662        if matches!(
663            auth_switch_request.auth_plugin(),
664            AuthPlugin::Other(Cow::Borrowed(b"mysql_clear_password"))
665        ) && !self.0.opts.get_enable_cleartext_plugin()
666        {
667            return Err(DriverError(CleartextPluginDisabled));
668        }
669
670        self.0.nonce = auth_switch_request.plugin_data().to_vec();
671        self.0.auth_plugin = auth_switch_request.auth_plugin().into_owned();
672        let plugin_data = match self.0.auth_plugin {
673            ref x @ AuthPlugin::MysqlOldPassword => {
674                if self.0.opts.get_secure_auth() {
675                    return Err(DriverError(OldMysqlPasswordDisabled));
676                }
677                x.gen_data(self.0.opts.get_pass(), &self.0.nonce)
678            }
679            ref x @ AuthPlugin::MysqlNativePassword => {
680                x.gen_data(self.0.opts.get_pass(), &self.0.nonce)
681            }
682            ref x @ AuthPlugin::CachingSha2Password => {
683                x.gen_data(self.0.opts.get_pass(), &self.0.nonce)
684            }
685            ref x @ AuthPlugin::MysqlClearPassword => {
686                if !self.0.opts.get_enable_cleartext_plugin() {
687                    return Err(DriverError(UnknownAuthPlugin(
688                        "mysql_clear_password".into(),
689                    )));
690                }
691
692                x.gen_data(self.0.opts.get_pass(), &self.0.nonce)
693            }
694            ref x @ AuthPlugin::Ed25519 => x.gen_data(self.0.opts.get_pass(), &self.0.nonce),
695            AuthPlugin::Other(_) => None,
696        };
697
698        if let Some(plugin_data) = plugin_data {
699            self.write_struct(&plugin_data.into_owned())?;
700        } else {
701            self.write_packet(&mut &[0_u8; 0][..])?;
702        }
703
704        self.continue_auth(true)
705    }
706
    /// Performs the connection-phase handshake on a freshly opened stream.
    ///
    /// Reads and validates the server's initial handshake packet, records the
    /// announced state, upgrades to TLS when requested, sends the handshake
    /// response, completes authentication and finally enables compression if
    /// `CLIENT_COMPRESS` was negotiated.
    ///
    /// # Errors
    ///
    /// * `UnsupportedProtocol` if the protocol version isn't 10;
    /// * `Protocol41NotSet` if the server lacks `CLIENT_PROTOCOL_41`;
    /// * `TlsNotSupported` if SSL options are set on a non-secure stream but
    ///   `CLIENT_SSL` wasn't negotiated;
    /// * any I/O, TLS or authentication error.
    fn do_handshake(&mut self) -> Result<()> {
        let payload = self.read_packet()?;
        let handshake = ParseBuf(&payload).parse::<HandshakePacket>(())?;

        if handshake.protocol_version() != 10u8 {
            return Err(DriverError(UnsupportedProtocol(
                handshake.protocol_version(),
            )));
        }

        if !handshake
            .capabilities()
            .contains(CapabilityFlags::CLIENT_PROTOCOL_41)
        {
            return Err(DriverError(Protocol41NotSet));
        }

        // Record negotiated capabilities, versions, status flags and connection id.
        self.handle_handshake(&handshake);

        // The TLS upgrade happens before the handshake response is written, so
        // the response (including auth data) travels over TLS.
        if self.is_insecure() {
            if let Some(ssl_opts) = self.0.opts.get_ssl_opts().cloned() {
                if !self.has_capability(CapabilityFlags::CLIENT_SSL) {
                    return Err(DriverError(TlsNotSupported));
                } else {
                    self.do_ssl_request()?;
                    self.switch_to_ssl(ssl_opts)?;
                }
            }
        }

        // Handshake scramble is always 21 bytes length (20 + zero terminator)
        self.0.nonce = {
            let mut nonce = Vec::from(handshake.scramble_1_ref());
            nonce.extend_from_slice(handshake.scramble_2_ref().unwrap_or(&[][..]));
            // Trim zero terminator. Fill with zeroes if nonce
            // is somehow smaller than 20 bytes (this matches the server behavior).
            nonce.resize(20, 0);
            nonce
        };

        // Allow only CachingSha2Password and MysqlNativePassword here
        // because sha256_password is deprecated and other plugins won't
        // appear here.
        self.0.auth_plugin = match handshake.auth_plugin() {
            Some(x @ AuthPlugin::CachingSha2Password) => x.into_owned(),
            _ => AuthPlugin::MysqlNativePassword,
        };

        self.write_handshake_response()?;
        self.continue_auth(false)?;

        // Compression is switched on only after authentication has completed.
        if self.has_capability(CapabilityFlags::CLIENT_COMPRESS) {
            self.switch_to_compressed();
        }

        Ok(())
    }
764
765    fn switch_to_compressed(&mut self) {
766        self.stream_mut()
767            .codec_mut()
768            .compress(Compression::default());
769    }
770
771    fn get_client_flags(&self) -> CapabilityFlags {
772        let mut client_flags = CapabilityFlags::CLIENT_PROTOCOL_41
773            | CapabilityFlags::CLIENT_SECURE_CONNECTION
774            | CapabilityFlags::CLIENT_LONG_PASSWORD
775            | CapabilityFlags::CLIENT_TRANSACTIONS
776            | CapabilityFlags::CLIENT_LOCAL_FILES
777            | CapabilityFlags::CLIENT_MULTI_STATEMENTS
778            | CapabilityFlags::CLIENT_MULTI_RESULTS
779            | CapabilityFlags::CLIENT_PS_MULTI_RESULTS
780            | CapabilityFlags::CLIENT_PLUGIN_AUTH
781            | (self.0.capability_flags & CapabilityFlags::CLIENT_LONG_FLAG);
782        if self.0.opts.get_compress().is_some() {
783            client_flags.insert(CapabilityFlags::CLIENT_COMPRESS);
784        }
785        if self.0.opts.get_connect_attrs().is_some() {
786            client_flags.insert(CapabilityFlags::CLIENT_CONNECT_ATTRS);
787        }
788        if let Some(db_name) = self.0.opts.get_db_name() {
789            if !db_name.is_empty() {
790                client_flags.insert(CapabilityFlags::CLIENT_CONNECT_WITH_DB);
791            }
792        }
793        if self.is_insecure() && self.0.opts.get_ssl_opts().is_some() {
794            client_flags.insert(CapabilityFlags::CLIENT_SSL);
795        }
796        client_flags | self.0.opts.get_additional_capabilities()
797    }
798
799    fn get_mariadb_client_flags(&self) -> MariadbCapabilities {
800        MariadbCapabilities::MARIADB_CLIENT_CACHE_METADATA
801    }
802
803    fn connect_attrs(&self) -> Option<HashMap<String, String>> {
804        if let Some(attrs) = self.0.opts.get_connect_attrs() {
805            let program_name = match attrs.get("program_name") {
806                Some(program_name) => program_name.clone(),
807                None => {
808                    let arg0 = std::env::args_os().next();
809                    let arg0 = arg0.as_ref().map(|x| x.to_string_lossy());
810                    arg0.unwrap_or_else(|| "".into()).into_owned()
811                }
812            };
813
814            let mut attrs_to_send = HashMap::new();
815
816            attrs_to_send.insert("_client_name".into(), "rust-mysql-simple".into());
817            attrs_to_send.insert("_client_version".into(), env!("CARGO_PKG_VERSION").into());
818            attrs_to_send.insert("_os".into(), env!("CARGO_CFG_TARGET_OS").into());
819            attrs_to_send.insert("_pid".into(), process::id().to_string());
820            attrs_to_send.insert("_platform".into(), env!("CARGO_CFG_TARGET_ARCH").into());
821            attrs_to_send.insert("program_name".into(), program_name);
822
823            for (name, value) in attrs.clone() {
824                attrs_to_send.insert(name, value);
825            }
826
827            Some(attrs_to_send)
828        } else {
829            None
830        }
831    }
832
833    fn do_ssl_request(&mut self) -> Result<()> {
834        let charset = if self.server_version() >= (5, 5, 3) {
835            UTF8MB4_GENERAL_CI
836        } else {
837            UTF8_GENERAL_CI
838        };
839
840        let ssl_request = SslRequest::new(
841            self.get_client_flags(),
842            DEFAULT_MAX_ALLOWED_PACKET as u32,
843            charset as u8,
844        );
845        self.write_struct(&ssl_request)
846    }
847
848    fn write_handshake_response(&mut self) -> Result<()> {
849        let auth_data = self
850            .0
851            .auth_plugin
852            .gen_data(self.0.opts.get_pass(), &self.0.nonce)
853            .map(|x| x.into_owned());
854
855        let handshake_response = HandshakeResponse::new(
856            auth_data.as_deref(),
857            self.0.server_version.unwrap_or((0, 0, 0)),
858            self.0.opts.get_user().map(str::as_bytes),
859            self.0.opts.get_db_name().map(str::as_bytes),
860            Some(self.0.auth_plugin.clone()),
861            self.0.capability_flags,
862            self.connect_attrs(),
863            self.0
864                .opts
865                .get_max_allowed_packet()
866                .unwrap_or(DEFAULT_MAX_ALLOWED_PACKET) as u32,
867        )
868        .with_mariadb_ext_capabilities(self.0.mariadb_ext_capabilities);
869
870        let mut buf = get_buffer();
871        handshake_response.serialize(buf.as_mut());
872        self.write_packet(&mut &*buf)
873    }
874
875    fn continue_auth(&mut self, auth_switched: bool) -> Result<()> {
876        match self.0.auth_plugin {
877            AuthPlugin::CachingSha2Password => {
878                self.continue_caching_sha2_password_auth(auth_switched)?;
879                Ok(())
880            }
881            AuthPlugin::MysqlNativePassword | AuthPlugin::MysqlOldPassword => {
882                self.continue_mysql_native_password_auth(auth_switched)?;
883                Ok(())
884            }
885            AuthPlugin::MysqlClearPassword => {
886                if !self.0.opts.get_enable_cleartext_plugin() {
887                    return Err(DriverError(CleartextPluginDisabled));
888                }
889                self.continue_mysql_native_password_auth(auth_switched)?;
890                Ok(())
891            }
892            AuthPlugin::Ed25519 => {
893                self.continue_ed25519_auth(auth_switched)?;
894                Ok(())
895            }
896            AuthPlugin::Other(ref name) => {
897                let plugin_name = String::from_utf8_lossy(name).into();
898                Err(DriverError(UnknownAuthPlugin(plugin_name)))
899            }
900        }
901    }
902
903    fn continue_mysql_native_password_auth(&mut self, auth_switched: bool) -> Result<()> {
904        let payload = self.read_packet()?;
905
906        match payload[0] {
907            // auth ok
908            0x00 => self.handle_ok::<CommonOkPacket>(&payload).map(drop),
909            // auth switch
910            0xfe if !auth_switched => {
911                let auth_switch = if payload.len() > 1 {
912                    ParseBuf(&payload).parse(())?
913                } else {
914                    let _ = ParseBuf(&payload).parse::<OldAuthSwitchRequest>(())?;
915                    // we'll map OldAuthSwitchRequest to an AuthSwitchRequest with mysql_old_password plugin.
916                    AuthSwitchRequest::new("mysql_old_password".as_bytes(), &*self.0.nonce)
917                        .into_owned()
918                };
919                self.perform_auth_switch(auth_switch)
920            }
921            _ => Err(DriverError(UnexpectedPacket)),
922        }
923    }
924
925    fn continue_caching_sha2_password_auth(&mut self, auth_switched: bool) -> Result<()> {
926        let payload = self.read_packet()?;
927
928        match payload[0] {
929            0x00 => {
930                // ok packet for empty password
931                Ok(())
932            }
933            0x01 => match payload[1] {
934                0x03 => {
935                    let payload = self.read_packet()?;
936                    self.handle_ok::<CommonOkPacket>(&payload).map(drop)
937                }
938                0x04 => {
939                    if !self.is_insecure() || self.is_socket() {
940                        let mut pass = self.0.opts.get_pass().map(Vec::from).unwrap_or_default();
941                        pass.push(0);
942                        self.write_packet(&mut pass.as_slice())?;
943                    } else {
944                        self.write_packet(&mut &[0x02][..])?;
945                        let payload = self.read_packet()?;
946                        let key = &payload[1..];
947                        let mut pass = self.0.opts.get_pass().map(Vec::from).unwrap_or_default();
948                        pass.push(0);
949                        for (i, c) in pass.iter_mut().enumerate() {
950                            *(c) ^= self.0.nonce[i % self.0.nonce.len()];
951                        }
952                        let encrypted_pass = crypto::encrypt(&pass, key);
953                        self.write_packet(&mut encrypted_pass.as_slice())?;
954                    }
955
956                    let payload = self.read_packet()?;
957                    self.handle_ok::<CommonOkPacket>(&payload).map(drop)
958                }
959                _ => Err(DriverError(UnexpectedPacket)),
960            },
961            0xfe if !auth_switched => {
962                let auth_switch_request = ParseBuf(&payload).parse(())?;
963                self.perform_auth_switch(auth_switch_request)
964            }
965            _ => Err(DriverError(UnexpectedPacket)),
966        }
967    }
968
969    fn continue_ed25519_auth(&mut self, auth_switched: bool) -> Result<()> {
970        let payload = self.read_packet()?;
971        match payload[0] {
972            // ok packet for empty password
973            0x00 => Ok(()),
974            0xfe if !auth_switched => {
975                let auth_switch_request = ParseBuf(&payload).parse(())?;
976                self.perform_auth_switch(auth_switch_request)
977            }
978            _ => Err(DriverError(UnexpectedPacket)),
979        }
980    }
981
982    fn reset_seq_id(&mut self) {
983        self.stream_mut().codec_mut().reset_seq_id();
984    }
985
986    fn sync_seq_id(&mut self) {
987        self.stream_mut().codec_mut().sync_seq_id();
988    }
989
990    fn write_command_raw<T: MySerialize>(&mut self, cmd: &T) -> Result<()> {
991        let mut buf = get_buffer();
992        cmd.serialize(buf.as_mut());
993        self.reset_seq_id();
994        debug_assert!(!buf.is_empty());
995        self.0.last_command = buf[0];
996        self.write_packet(&mut &*buf)
997    }
998
999    fn write_command(&mut self, cmd: Command, data: &[u8]) -> Result<()> {
1000        let mut buf = get_buffer();
1001        buf.as_mut().put_u8(cmd as u8);
1002        buf.as_mut().extend_from_slice(data);
1003
1004        self.reset_seq_id();
1005        self.0.last_command = buf[0];
1006        self.write_packet(&mut &*buf)
1007    }
1008
1009    fn send_long_data(&mut self, stmt_id: u32, params: &[Value]) -> Result<()> {
1010        for (i, value) in params.iter().enumerate() {
1011            if let Bytes(bytes) = value {
1012                let chunks = bytes.chunks(MAX_PAYLOAD_LEN - 6);
1013                let chunks = chunks.chain(if bytes.is_empty() {
1014                    Some(&[][..])
1015                } else {
1016                    None
1017                });
1018                for chunk in chunks {
1019                    let cmd = ComStmtSendLongData::new(stmt_id, i as u16, Cow::Borrowed(chunk));
1020                    self.write_command_raw(&cmd)?;
1021                }
1022            }
1023        }
1024
1025        Ok(())
1026    }
1027
1028    fn _execute(&mut self, stmt: &Statement, params: Params) -> Result<ResultSetInfo> {
1029        let exec_request = match &params {
1030            Params::Empty => {
1031                if stmt.num_params() != 0 {
1032                    return Err(DriverError(MismatchedStmtParams(stmt.num_params(), 0)));
1033                }
1034
1035                let (body, _) = ComStmtExecuteRequestBuilder::new(stmt.id()).build(&[]);
1036                body
1037            }
1038            Params::Positional(params) => {
1039                if stmt.num_params() != params.len() as u16 {
1040                    return Err(DriverError(MismatchedStmtParams(
1041                        stmt.num_params(),
1042                        params.len(),
1043                    )));
1044                }
1045
1046                let (body, as_long_data) =
1047                    ComStmtExecuteRequestBuilder::new(stmt.id()).build(params);
1048
1049                if as_long_data {
1050                    self.send_long_data(stmt.id(), params)?;
1051                }
1052
1053                body
1054            }
1055            Params::Named(_) => {
1056                if let Some(named_params) = stmt.named_params.as_ref() {
1057                    return self._execute(stmt, params.into_positional(named_params)?);
1058                } else {
1059                    return Err(DriverError(NamedParamsForPositionalQuery));
1060                }
1061            }
1062        };
1063        self.write_command_raw(&exec_request)?;
1064        self.handle_result_set()
1065    }
1066
1067    fn _start_transaction(&mut self, tx_opts: TxOpts) -> Result<()> {
1068        if let Some(i_level) = tx_opts.isolation_level() {
1069            self.query_drop(format!("SET TRANSACTION ISOLATION LEVEL {}", i_level))?;
1070        }
1071        if let Some(mode) = tx_opts.access_mode() {
1072            let supported = match (self.0.server_version, self.0.mariadb_server_version) {
1073                (Some(ref version), _) if *version >= (5, 6, 5) => true,
1074                (_, Some(ref version)) if *version >= (10, 0, 0) => true,
1075                _ => false,
1076            };
1077            if !supported {
1078                return Err(DriverError(ReadOnlyTransNotSupported));
1079            }
1080            match mode {
1081                AccessMode::ReadOnly => self.query_drop("SET TRANSACTION READ ONLY")?,
1082                AccessMode::ReadWrite => self.query_drop("SET TRANSACTION READ WRITE")?,
1083            }
1084        }
1085        if tx_opts.with_consistent_snapshot() {
1086            self.query_drop("START TRANSACTION WITH CONSISTENT SNAPSHOT")
1087                .unwrap();
1088        } else {
1089            self.query_drop("START TRANSACTION")?;
1090        };
1091        Ok(())
1092    }
1093
    /// Answers the server's `LOCAL INFILE` request for `file_name` and returns the
    /// server's final OK packet.
    ///
    /// The handler set on this connection takes precedence over the one from `Opts`.
    /// The handler writes the file contents into a `LocalInfile` writer backed by a
    /// `LocalInfile::BUFFER_SIZE` stack buffer; the writer is then flushed. With no
    /// handler configured, nothing is written before the final empty packet. An empty
    /// packet is always written afterwards (the protocol's end-of-data marker), and
    /// the OK reply is read and returned as an owned packet.
    fn send_local_infile(&mut self, file_name: &[u8]) -> Result<OkPacket<'static>> {
        // Scope the `LocalInfile` writer, which borrows `self`, so that it is released
        // before the terminating packet is written below.
        {
            let mut buffer = [0_u8; LocalInfile::BUFFER_SIZE];
            let maybe_handler = self
                .0
                .local_infile_handler
                .clone()
                .or_else(|| self.0.opts.get_local_infile_handler().cloned());
            let mut local_infile = LocalInfile::new(&mut buffer, self);
            if let Some(handler) = maybe_handler {
                // A `lock()` failure (e.g. a poisoned mutex, assuming `handler.0` is a
                // `std::sync::Mutex` — confirm) is propagated with `?`, not unwrapped.
                // The lock cannot be re-entered from here: we hold exclusive access to
                // `self`, and `LocalInfile` does not expose the connection to the handler.
                let handler_fn = &mut *handler.0.lock()?;
                handler_fn(file_name, &mut local_infile)?;
            }
            local_infile.flush()?;
        }
        self.write_packet(&mut &[][..])?;
        let payload = self.read_packet()?;
        let ok = self.handle_ok::<CommonOkPacket>(&payload)?;
        Ok(ok.into_owned())
    }
1117
1118    fn handle_result_set(&mut self) -> Result<ResultSetInfo> {
1119        if self.more_results_exists() {
1120            self.sync_seq_id();
1121        }
1122
1123        let pld = self.read_packet()?;
1124        match pld[0] {
1125            0x00 => {
1126                let ok = self.handle_ok::<CommonOkPacket>(&pld)?;
1127                Ok(ResultSetInfo::Empty(ok.into_owned()))
1128            }
1129            0xfb => match self.send_local_infile(&pld[1..]) {
1130                Ok(ok) => Ok(ResultSetInfo::Empty(ok)),
1131                Err(err) => Err(err),
1132            },
1133            _ => {
1134                let mut reader = &pld[..];
1135                let column_count = reader.read_lenenc_int()?;
1136
1137                let mut columns: Vec<Column> = Vec::new();
1138
1139                // https://jira.mariadb.org/browse/MDEV-19237
1140                let output = if !(self
1141                    .has_mariadb_capability(MariadbCapabilities::MARIADB_CLIENT_CACHE_METADATA)
1142                    && reader.first().copied() == Some(0x00))
1143                {
1144                    columns.reserve(column_count as usize);
1145                    for _ in 0..column_count {
1146                        let pld = self.read_packet()?;
1147                        let column = ParseBuf(&pld).parse(())?;
1148                        columns.push(column);
1149                    }
1150
1151                    ResultSetInfo::NonEmptyWithMeta(columns)
1152                } else {
1153                    ResultSetInfo::NonEmptySkipMeta
1154                };
1155
1156                // skip eof packet
1157                self.drop_packet()?;
1158                self.0.has_results = column_count > 0;
1159
1160                Ok(output)
1161            }
1162        }
1163    }
1164
1165    fn _query(&mut self, query: &str) -> Result<ResultSetMeta> {
1166        self.write_command(Command::COM_QUERY, query.as_bytes())?;
1167        let info = self.handle_result_set()?;
1168        let meta = info.into_query_meta();
1169        Ok(meta)
1170    }
1171
1172    /// Executes [`COM_PING`](https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_ping.html)
1173    /// on `Conn`. Return `true` on success or `false` on error.
1174    pub fn ping(&mut self) -> Result<(), Error> {
1175        self.write_command(Command::COM_PING, &[])?;
1176        self.drop_packet()
1177    }
1178
1179    /// Executes [`COM_INIT_DB`](https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_init_db.html)
1180    /// on `Conn`.
1181    pub fn select_db(&mut self, schema: &str) -> Result<(), Error> {
1182        self.write_command(Command::COM_INIT_DB, schema.as_bytes())?;
1183        self.drop_packet()
1184    }
1185
1186    /// Starts new transaction with provided options.
1187    /// `readonly` is only available since MySQL 5.6.5.
1188    pub fn start_transaction(&mut self, tx_opts: TxOpts) -> Result<Transaction> {
1189        self._start_transaction(tx_opts)?;
1190        Ok(Transaction::new(self.into()))
1191    }
1192
1193    fn _true_prepare(&mut self, query: &[u8]) -> Result<InnerStmt> {
1194        self.write_command(Command::COM_STMT_PREPARE, query)?;
1195        let pld = self.read_packet()?;
1196        let mut stmt = ParseBuf(&pld).parse::<InnerStmt>(self.connection_id())?;
1197        if stmt.num_params() > 0 {
1198            let mut params: Vec<Column> = Vec::with_capacity(stmt.num_params() as usize);
1199            for _ in 0..stmt.num_params() {
1200                let pld = self.read_packet()?;
1201                params.push(ParseBuf(&pld).parse(())?);
1202            }
1203            stmt = stmt.with_params(Some(params));
1204            self.drop_packet()?;
1205        }
1206        if stmt.num_columns() > 0 {
1207            let mut columns: Vec<Column> = Vec::with_capacity(stmt.num_columns() as usize);
1208            for _ in 0..stmt.num_columns() {
1209                let pld = self.read_packet()?;
1210                columns.push(ParseBuf(&pld).parse(())?);
1211            }
1212            stmt = stmt.with_columns(Some(columns));
1213            self.drop_packet()?;
1214        }
1215        Ok(stmt)
1216    }
1217
1218    fn _prepare(&mut self, query: &[u8]) -> Result<Arc<InnerStmt>> {
1219        if let Some(entry) = self.0.stmt_cache.by_query(query) {
1220            return Ok(entry.stmt.clone());
1221        }
1222
1223        let inner_st = Arc::new(self._true_prepare(query)?);
1224
1225        if let Some(old_stmt) = self
1226            .0
1227            .stmt_cache
1228            .put(Arc::new(query.into()), inner_st.clone())
1229        {
1230            self.close(Statement::new(old_stmt, None))?;
1231        }
1232
1233        Ok(inner_st)
1234    }
1235
1236    fn connect(&mut self) -> Result<()> {
1237        if self.0.connected {
1238            return Ok(());
1239        }
1240        self.do_handshake()
1241            .and_then(|_| match self.0.opts.get_max_allowed_packet() {
1242                Some(x) => Ok(x),
1243                None => Ok(from_value_opt::<usize>(
1244                    self.get_system_var("max_allowed_packet")?.unwrap_or(NULL),
1245                )
1246                .unwrap_or(0)),
1247            })
1248            .and_then(|max_allowed_packet| {
1249                if max_allowed_packet == 0 {
1250                    Err(DriverError(SetupError))
1251                } else {
1252                    self.stream_mut().codec_mut().max_allowed_packet = max_allowed_packet;
1253                    self.0.connected = true;
1254                    Ok(())
1255                }
1256            })
1257    }
1258
1259    fn get_system_var(&mut self, name: &str) -> Result<Option<Value>> {
1260        self.query_first(format!("SELECT @@{}", name))
1261    }
1262
1263    fn next_row_packet(&mut self) -> Result<Option<Buffer>> {
1264        if !self.0.has_results {
1265            return Ok(None);
1266        }
1267
1268        let pld = self.read_packet()?;
1269
1270        if self.has_capability(CapabilityFlags::CLIENT_DEPRECATE_EOF) {
1271            if pld[0] == 0xfe && pld.len() < MAX_PAYLOAD_LEN {
1272                self.0.has_results = false;
1273                self.handle_ok::<ResultSetTerminator>(&pld)?;
1274                return Ok(None);
1275            }
1276        } else if pld[0] == 0xfe && pld.len() < 8 {
1277            self.0.has_results = false;
1278            self.handle_ok::<OldEofPacket>(&pld)?;
1279            return Ok(None);
1280        }
1281
1282        Ok(Some(pld))
1283    }
1284
1285    fn has_stmt(&self, query: &[u8]) -> bool {
1286        self.0.stmt_cache.contains_query(query)
1287    }
1288
1289    /// Sets a callback to handle requests for local files. These are
1290    /// caused by using `LOAD DATA LOCAL INFILE` queries. The
1291    /// callback is passed the filename, and a `Write`able object
1292    /// to receive the contents of that file.
1293    /// Specifying `None` will reset the handler to the one specified
1294    /// in the `Opts` for this connection.
1295    pub fn set_local_infile_handler(&mut self, handler: Option<LocalInfileHandler>) {
1296        self.0.local_infile_handler = handler;
1297    }
1298
1299    pub fn no_backslash_escape(&self) -> bool {
1300        self.0
1301            .status_flags
1302            .contains(StatusFlags::SERVER_STATUS_NO_BACKSLASH_ESCAPES)
1303    }
1304
1305    #[cfg(feature = "binlog")]
1306    fn register_as_slave(&mut self, server_id: u32) -> Result<()> {
1307        use mysql_common::packets::ComRegisterSlave;
1308
1309        self.query_drop("SET @master_binlog_checksum='ALL'")?;
1310        self.write_command_raw(&ComRegisterSlave::new(server_id))?;
1311
1312        // Server will respond with OK.
1313        self.read_packet()?;
1314
1315        Ok(())
1316    }
1317
1318    #[cfg(feature = "binlog")]
1319    fn request_binlog(&mut self, request: BinlogRequest<'_>) -> Result<()> {
1320        self.register_as_slave(request.server_id())?;
1321        self.write_command_raw(&request.as_cmd())?;
1322        Ok(())
1323    }
1324
1325    /// Turns this connection into a binlog stream.
1326    ///
1327    /// You can use `SHOW BINARY LOGS` to get the current log file and position from the master.
1328    /// If the request's `filename` is empty, the server will send the binlog-stream
1329    /// of the first known binlog.
1330    #[cfg(feature = "binlog")]
1331    #[cfg_attr(docsrs, doc(cfg(feature = "binlog")))]
1332    pub fn get_binlog_stream(mut self, request: BinlogRequest<'_>) -> Result<BinlogStream> {
1333        self.request_binlog(request)?;
1334        Ok(BinlogStream::new(self))
1335    }
1336
1337    fn cleanup_for_pool(&mut self) -> Result<()> {
1338        self.set_local_infile_handler(None);
1339        if self.0.reset_upon_return {
1340            self.reset()?;
1341        }
1342
1343        self.0.reset_upon_return = self.0.opts.get_pool_opts().reset_connection();
1344
1345        Ok(())
1346    }
1347}
1348
1349#[cfg(unix)]
1350impl AsRawFd for Conn {
1351    fn as_raw_fd(&self) -> RawFd {
1352        self.stream_ref().get_ref().as_raw_fd()
1353    }
1354}
1355
1356impl Queryable for Conn {
1357    fn query_iter<T: AsRef<str>>(&mut self, query: T) -> Result<QueryResult<'_, '_, '_, Text>> {
1358        let meta = self._query(query.as_ref())?;
1359        Ok(QueryResult::new(ConnMut::Mut(self), meta))
1360    }
1361
1362    fn prep<T: AsRef<str>>(&mut self, query: T) -> Result<Statement> {
1363        let query = query.as_ref();
1364        let parsed = ParsedNamedParams::parse(query.as_bytes())?;
1365        let named_params: Vec<Vec<u8>> =
1366            parsed.params().iter().map(|param| param.to_vec()).collect();
1367        let named_params = if named_params.is_empty() {
1368            None
1369        } else {
1370            Some(named_params)
1371        };
1372        self._prepare(parsed.borrow().query())
1373            .map(|inner| Statement::new(inner, named_params))
1374    }
1375
1376    fn close(&mut self, stmt: Statement) -> Result<()> {
1377        self.0.stmt_cache.remove(stmt.id());
1378        let cmd = ComStmtClose::new(stmt.id());
1379        self.write_command_raw(&cmd)
1380    }
1381
1382    fn exec_iter<S, P>(&mut self, stmt: S, params: P) -> Result<QueryResult<'_, '_, '_, Binary>>
1383    where
1384        S: AsStatement,
1385        P: Into<Params>,
1386    {
1387        let statement = stmt.as_statement(self)?;
1388        let info = self._execute(&statement, params.into())?;
1389        let meta = info.into_statement_meta(&*self, &statement);
1390        Ok(QueryResult::new(ConnMut::Mut(self), meta))
1391    }
1392}
1393
1394impl Drop for Conn {
1395    fn drop(&mut self) {
1396        let stmt_cache = mem::replace(&mut self.0.stmt_cache, StmtCache::new(0));
1397
1398        for (_, entry) in stmt_cache.into_iter() {
1399            let _ = self.close(Statement::new(entry.stmt, None));
1400        }
1401
1402        if self.0.stream.is_some() {
1403            let _ = self.write_command(Command::COM_QUIT, &[]);
1404        }
1405    }
1406}
1407
1408#[cfg(test)]
1409#[allow(non_snake_case)]
1410mod test {
1411    mod my_conn {
1412        use std::{
1413            collections::HashMap,
1414            io::Write,
1415            process,
1416            sync::mpsc::{channel, sync_channel},
1417            thread::spawn,
1418            time::Duration,
1419        };
1420
1421        #[cfg(feature = "binlog")]
1422        use mysql_common::{binlog::events::EventData, packets::binlog_request::BinlogRequest};
1423        use rand::Fill;
1424        #[cfg(feature = "time")]
1425        use time::PrimitiveDateTime;
1426
1427        use crate::{
1428            conn::ConnInner,
1429            from_row, from_value, params,
1430            prelude::*,
1431            test_misc::get_opts,
1432            Conn,
1433            DriverError::{MissingNamedParameter, NamedParamsForPositionalQuery},
1434            Error::DriverError,
1435            LocalInfileHandler, Opts, OptsBuilder, Pool, TxOpts,
1436            Value::{self, Bytes, Date, Float, Int, NULL},
1437        };
1438
1439        fn get_system_variable<T>(conn: &mut Conn, name: &str) -> T
1440        where
1441            T: FromValue,
1442        {
1443            conn.query_first::<(String, T), _>(format!("show variables like '{}'", name))
1444                .unwrap()
1445                .unwrap()
1446                .1
1447        }
1448
1449        #[test]
1450        fn should_connect() {
1451            let mut conn = Conn::new(get_opts()).unwrap();
1452
1453            let mode: String = conn
1454                .query_first("SELECT @@GLOBAL.sql_mode")
1455                .unwrap()
1456                .unwrap();
1457            assert!(mode.contains("TRADITIONAL"));
1458            assert!(conn.ping().is_ok());
1459
1460            if crate::test_misc::test_compression() {
1461                assert!(format!("{:?}", conn.0.stream).contains("Compression"));
1462            }
1463
1464            if crate::test_misc::test_ssl() {
1465                assert!(!conn.is_insecure());
1466            }
1467        }
1468
        /// Regression test named after mysql_async issue #107.
        ///
        /// The table rows mix `BIGINT UNSIGNED`, `VARBINARY` and `BINARY` columns,
        /// including empty-string (`''`) values in the binary columns. Reading them back
        /// through the text protocol into `(Vec<u8>, Vec<u8>, u64, Vec<u8>)` must yield
        /// both rows.
        #[test]
        fn mysql_async_issue_107() -> crate::Result<()> {
            let mut conn = Conn::new(get_opts())?;
            conn.query_drop(
                r"CREATE TEMPORARY TABLE mysql.issue (
                        a BIGINT(20) UNSIGNED,
                        b VARBINARY(16),
                        c BINARY(32),
                        d BIGINT(20) UNSIGNED,
                        e BINARY(32)
                    )",
            )?;
            conn.query_drop(
                r"INSERT INTO mysql.issue VALUES (
                        0,
                        0xC066F966B0860000,
                        0x7939DA98E524C5F969FC2DE8D905FD9501EBC6F20001B0A9C941E0BE6D50CF44,
                        0,
                        ''
                    ), (
                        1,
                        '',
                        0x076311DF4D407B0854371BA13A5F3FB1A4555AC22B361375FD47B263F31822F2,
                        0,
                        ''
                    )",
            )?;

            let q = "SELECT b, c, d, e FROM mysql.issue";
            let result = conn.query_iter(q)?;

            // Each row must convert into the tuple; a conversion failure panics here.
            let loaded_structs = result
                .map(|row| crate::from_row::<(Vec<u8>, Vec<u8>, u64, Vec<u8>)>(row.unwrap()))
                .collect::<Vec<_>>();

            assert_eq!(loaded_structs.len(), 2);

            Ok(())
        }
1508
        /// Exercises the query helper methods on string queries: `run`, `batch`,
        /// `first`, `fetch`, `map` and `fold`. Each is tested in plain (text) form and
        /// in `.with(...)` (binary) form.
        ///
        /// The same checks run against four targets: a `Conn`, a transaction on it, a
        /// pooled connection, and a pool transaction.
        #[test]
        fn query_traits() -> Result<(), Box<dyn std::error::Error>> {
            // Runs the full sequence of checks against one queryable `$conn`, leaving
            // no `tmplak` table behind.
            macro_rules! test_query {
                ($conn : expr) => {
                    "CREATE TABLE IF NOT EXISTS tmplak (a INT)"
                        .run($conn)
                        .unwrap();
                    "DELETE FROM tmplak".run($conn).unwrap();

                    "INSERT INTO tmplak (a) VALUES (?)"
                        .with((42,))
                        .run($conn)
                        .unwrap();

                    "INSERT INTO tmplak (a) VALUES (?)"
                        .with((43..=44).map(|x| (x,)))
                        .batch($conn)?;

                    let first: Option<u8> = "SELECT a FROM tmplak LIMIT 1".first($conn).unwrap();
                    assert_eq!(first, Some(42), "first text");

                    let first: Option<u8> = "SELECT a FROM tmplak LIMIT 1"
                        .with(())
                        .first($conn)
                        .unwrap();
                    assert_eq!(first, Some(42), "first bin");

                    let count = "SELECT a FROM tmplak".run($conn).unwrap().count();
                    assert_eq!(count, 3, "run text");

                    let count = "SELECT a FROM tmplak".with(()).run($conn).unwrap().count();
                    assert_eq!(count, 3, "run bin");

                    let all: Vec<u8> = "SELECT a FROM tmplak".fetch($conn).unwrap();
                    assert_eq!(all, vec![42, 43, 44], "fetch text");

                    let all: Vec<u8> = "SELECT a FROM tmplak".with(()).fetch($conn).unwrap();
                    assert_eq!(all, vec![42, 43, 44], "fetch bin");

                    let mapped = "SELECT a FROM tmplak".map($conn, |x: u8| x + 1).unwrap();
                    assert_eq!(mapped, vec![43, 44, 45], "map text");

                    let mapped = "SELECT a FROM tmplak"
                        .with(())
                        .map($conn, |x: u8| x + 1)
                        .unwrap();
                    assert_eq!(mapped, vec![43, 44, 45], "map bin");

                    // 42 + 43 + 44 = 129 still fits in the `u8` accumulator.
                    let sum = "SELECT a FROM tmplak"
                        .fold($conn, 0_u8, |acc, x: u8| acc + x)
                        .unwrap();
                    assert_eq!(sum, 42 + 43 + 44, "fold text");

                    let sum = "SELECT a FROM tmplak"
                        .with(())
                        .fold($conn, 0_u8, |acc, x: u8| acc + x)
                        .unwrap();
                    assert_eq!(sum, 42 + 43 + 44, "fold bin");

                    "DROP TABLE tmplak".run($conn).unwrap();
                };
            }

            let mut conn = Conn::new(get_opts())?;

            // NOTE(review): in MySQL, CREATE/DROP TABLE commit implicitly, so this
            // `rollback` likely has nothing left to undo — confirm that is intended.
            let mut tx = conn.start_transaction(TxOpts::default())?;
            test_query!(&mut tx);
            tx.rollback()?;

            test_query!(&mut conn);

            let pool = Pool::new(get_opts())?;
            let mut pooled_conn = pool.get_conn()?;

            let mut tx = pool.start_transaction(TxOpts::default())?;
            test_query!(&mut tx);
            tx.rollback()?;

            test_query!(&mut pooled_conn);

            Ok(())
        }
1591
1592        #[test]
1593        #[should_panic(expected = "Could not connect to address")]
1594        fn should_fail_on_wrong_socket_path() {
1595            let opts = OptsBuilder::from_opts(get_opts()).socket(Some("/foo/bar/baz"));
1596            let _ = Conn::new(opts).unwrap();
1597        }
1598
1599        #[test]
1600        fn should_fallback_to_tcp_if_cant_switch_to_socket() {
1601            let mut opts = Opts::from(get_opts());
1602            opts.0.injected_socket = Some("/foo/bar/baz".into());
1603            let _ = Conn::new(opts).unwrap();
1604        }
1605
1606        #[test]
1607        fn should_connect_with_database() {
1608            const DB_NAME: &str = "mysql";
1609
1610            let opts = OptsBuilder::from_opts(get_opts()).db_name(Some(DB_NAME));
1611
1612            let mut conn = Conn::new(opts).unwrap();
1613
1614            let db_name: String = conn.query_first("SELECT DATABASE()").unwrap().unwrap();
1615            assert_eq!(db_name, DB_NAME);
1616        }
1617
1618        #[test]
1619        fn should_connect_by_hostname() {
1620            let opts = OptsBuilder::from_opts(get_opts()).ip_or_hostname(Some("localhost"));
1621            let mut conn = Conn::new(opts).unwrap();
1622            assert!(conn.ping().is_ok());
1623        }
1624
1625        #[test]
1626        fn should_select_db() {
1627            const DB_NAME: &str = "t_select_db";
1628
1629            let mut conn = Conn::new(get_opts()).unwrap();
1630            conn.query_drop(format!("CREATE DATABASE IF NOT EXISTS {}", DB_NAME))
1631                .unwrap();
1632            assert!(conn.select_db(DB_NAME).is_ok());
1633
1634            let db_name: String = conn.query_first("SELECT DATABASE()").unwrap().unwrap();
1635            assert_eq!(db_name, DB_NAME);
1636
1637            conn.query_drop(format!("DROP DATABASE {}", DB_NAME))
1638                .unwrap();
1639        }
1640
        /// End-to-end text-protocol check on a temporary table.
        ///
        /// Covers:
        /// * `affected_rows` / `last_insert_id` after CREATE, INSERT and UPDATE;
        /// * that a failing query and a dropped, unconsumed `QueryResult` leave the
        ///   connection usable;
        /// * that every column type in the table reads back as its textual `String`
        ///   form.
        #[test]
        fn should_execute_queries_and_parse_results() {
            type TestRow = (String, String, String, String, String, String);

            const CREATE_QUERY: &str = r"CREATE TEMPORARY TABLE mysql.tbl
                (id SERIAL, a TEXT, b INT, c INT UNSIGNED, d DATE, e FLOAT)";
            const INSERT_QUERY_1: &str = r"INSERT
                INTO mysql.tbl(a, b, c, d, e)
                VALUES ('hello', -123, 123, '2014-05-05', 123.123)";
            const INSERT_QUERY_2: &str = r"INSERT
                INTO mysql.tbl(a, b, c, d, e)
                VALUES ('world', -321, 321, '2014-06-06', 321.321)";

            let mut conn = Conn::new(get_opts()).unwrap();

            conn.query_drop(CREATE_QUERY).unwrap();
            assert_eq!(conn.affected_rows(), 0);
            assert_eq!(conn.last_insert_id(), 0);

            conn.query_drop(INSERT_QUERY_1).unwrap();
            assert_eq!(conn.affected_rows(), 1);
            assert_eq!(conn.last_insert_id(), 1);

            conn.query_drop(INSERT_QUERY_2).unwrap();
            assert_eq!(conn.affected_rows(), 1);
            assert_eq!(conn.last_insert_id(), 2);

            // A failing query and an unconsumed result set must not break the
            // connection for the statements that follow.
            conn.query_drop("SELECT * FROM nonexistent").unwrap_err();
            conn.query_iter("SELECT * FROM mysql.tbl").unwrap(); // Drop::drop for QueryResult

            conn.query_drop("UPDATE mysql.tbl SET a = 'foo'").unwrap();
            assert_eq!(conn.affected_rows(), 2);
            assert_eq!(conn.last_insert_id(), 0);

            assert!(conn
                .query_first::<TestRow, _>("SELECT * FROM mysql.tbl WHERE a = 'bar'")
                .unwrap()
                .is_none());

            let rows: Vec<TestRow> = conn.query("SELECT * FROM mysql.tbl").unwrap();
            assert_eq!(
                rows,
                vec![
                    (
                        "1".into(),
                        "foo".into(),
                        "-123".into(),
                        "123".into(),
                        "2014-05-05".into(),
                        "123.123".into()
                    ),
                    (
                        "2".into(),
                        "foo".into(),
                        "-321".into(),
                        "321".into(),
                        "2014-06-06".into(),
                        "321.321".into()
                    )
                ]
            );
        }
1703
1704        #[test]
1705        fn should_parse_large_text_result() {
1706            let mut conn = Conn::new(get_opts()).unwrap();
1707            let value: Value = conn
1708                .query_first("SELECT REPEAT('A', 20000000)")
1709                .unwrap()
1710                .unwrap();
1711            assert_eq!(
1712                value,
1713                Bytes(std::iter::repeat_n(b'A', 20_000_000).collect())
1714            );
1715        }
1716
1717        #[test]
1718        fn should_execute_statements_and_parse_results() {
1719            const CREATE_QUERY: &str = r"CREATE TEMPORARY TABLE
1720                mysql.tbl (a TEXT, b INT, c INT UNSIGNED, d DATE, e FLOAT)";
1721            const INSERT_STMT: &str = r"INSERT
1722                INTO mysql.tbl (a, b, c, d, e)
1723                VALUES (?, ?, ?, ?, ?)";
1724
1725            type RowType = (Value, Value, Value, Value, Value);
1726
1727            let row1 = (
1728                Bytes(b"hello".to_vec()),
1729                Int(-123_i64),
1730                Int(123_i64),
1731                Date(2014_u16, 5_u8, 5_u8, 0_u8, 0_u8, 0_u8, 0_u32),
1732                Float(123.123_f32),
1733            );
1734            let row2 = (Bytes(b"".to_vec()), NULL, NULL, NULL, Float(321.321_f32));
1735
1736            let mut conn = Conn::new(get_opts()).unwrap();
1737            conn.query_drop(CREATE_QUERY).unwrap();
1738
1739            let insert_stmt = conn.prep(INSERT_STMT).unwrap();
1740            assert_eq!(insert_stmt.connection_id(), conn.connection_id());
1741            conn.exec_drop(
1742                &insert_stmt,
1743                (
1744                    from_value::<String>(row1.0.clone()),
1745                    from_value::<i32>(row1.1.clone()),
1746                    from_value::<u32>(row1.2.clone()),
1747                    from_value::<time::PrimitiveDateTime>(row1.3.clone()),
1748                    from_value::<f32>(row1.4.clone()),
1749                ),
1750            )
1751            .unwrap();
1752            conn.exec_drop(
1753                &insert_stmt,
1754                (
1755                    from_value::<String>(row2.0.clone()),
1756                    row2.1.clone(),
1757                    row2.2.clone(),
1758                    row2.3.clone(),
1759                    from_value::<f32>(row2.4.clone()),
1760                ),
1761            )
1762            .unwrap();
1763
1764            let select_stmt = conn.prep("SELECT * from mysql.tbl").unwrap();
1765            let rows: Vec<RowType> = conn.exec(&select_stmt, ()).unwrap();
1766
1767            assert_eq!(rows, vec![row1, row2]);
1768        }
1769
1770        #[test]
1771        fn should_parse_large_binary_result() {
1772            let mut conn = Conn::new(get_opts()).unwrap();
1773            let stmt = conn.prep("SELECT REPEAT('A', 20000000)").unwrap();
1774            let value: Value = conn.exec_first(&stmt, ()).unwrap().unwrap();
1775            assert_eq!(
1776                value,
1777                Bytes(std::iter::repeat_n(b'A', 20_000_000).collect())
1778            );
1779        }
1780
1781        #[test]
1782        fn manually_closed_stmt() {
1783            let opts = get_opts().stmt_cache_size(1);
1784            let mut conn = Conn::new(opts).unwrap();
1785            let stmt = conn.prep("SELECT 1").unwrap();
1786            conn.exec_drop(&stmt, ()).unwrap();
1787            conn.close(stmt).unwrap();
1788            let stmt = conn.prep("SELECT 1").unwrap();
1789            conn.exec_drop(&stmt, ()).unwrap();
1790            conn.close(stmt).unwrap();
1791            let stmt = conn.prep("SELECT 2").unwrap();
1792            conn.exec_drop(&stmt, ()).unwrap();
1793        }
1794
1795        #[test]
1796        fn should_start_commit_and_rollback_transactions() {
1797            let mut conn = Conn::new(get_opts()).unwrap();
1798            conn.query_drop(
1799                "CREATE TEMPORARY TABLE mysql.tbl(id INT NOT NULL PRIMARY KEY AUTO_INCREMENT, a INT)",
1800            )
1801            .unwrap();
1802            conn.start_transaction(TxOpts::default())
1803                .map(|mut t| {
1804                    t.query_drop("INSERT INTO mysql.tbl(a) VALUES(1)").unwrap();
1805                    assert_eq!(t.last_insert_id(), Some(1));
1806                    assert_eq!(t.affected_rows(), 1);
1807                    t.query_drop("INSERT INTO mysql.tbl(a) VALUES(2)").unwrap();
1808                    t.commit().unwrap();
1809                })
1810                .unwrap();
1811            assert_eq!(
1812                conn.query_iter("SELECT COUNT(a) from mysql.tbl")
1813                    .unwrap()
1814                    .next()
1815                    .unwrap()
1816                    .unwrap()
1817                    .unwrap(),
1818                vec![Bytes(b"2".to_vec())]
1819            );
1820            conn.start_transaction(TxOpts::default())
1821                .map(|mut t| {
1822                    t.query_drop("INSERT INTO tbl2(a) VALUES(1)").unwrap_err();
1823                    // implicit rollback
1824                })
1825                .unwrap();
1826            assert_eq!(
1827                conn.query_iter("SELECT COUNT(a) from mysql.tbl")
1828                    .unwrap()
1829                    .next()
1830                    .unwrap()
1831                    .unwrap()
1832                    .unwrap(),
1833                vec![Bytes(b"2".to_vec())]
1834            );
1835            conn.start_transaction(TxOpts::default())
1836                .map(|mut t| {
1837                    t.query_drop("INSERT INTO mysql.tbl(a) VALUES(1)").unwrap();
1838                    t.query_drop("INSERT INTO mysql.tbl(a) VALUES(2)").unwrap();
1839                    t.rollback().unwrap();
1840                })
1841                .unwrap();
1842            assert_eq!(
1843                conn.query_iter("SELECT COUNT(a) from mysql.tbl")
1844                    .unwrap()
1845                    .next()
1846                    .unwrap()
1847                    .unwrap()
1848                    .unwrap(),
1849                vec![Bytes(b"2".to_vec())]
1850            );
1851            let mut tx = conn.start_transaction(TxOpts::default()).unwrap();
1852            tx.exec_drop("INSERT INTO mysql.tbl(a) VALUES(?)", (3,))
1853                .unwrap();
1854            tx.exec_drop("INSERT INTO mysql.tbl(a) VALUES(?)", (4,))
1855                .unwrap();
1856            tx.commit().unwrap();
1857            assert_eq!(
1858                conn.query_iter("SELECT COUNT(a) from mysql.tbl")
1859                    .unwrap()
1860                    .next()
1861                    .unwrap()
1862                    .unwrap()
1863                    .unwrap(),
1864                vec![Bytes(b"4".to_vec())]
1865            );
1866            let mut tx = conn.start_transaction(TxOpts::default()).unwrap();
1867            tx.exec_drop("INSERT INTO mysql.tbl(a) VALUES(?)", (5,))
1868                .unwrap();
1869            tx.exec_drop("INSERT INTO mysql.tbl(a) VALUES(?)", (6,))
1870                .unwrap();
1871            drop(tx);
1872            assert_eq!(
1873                conn.query_first("SELECT COUNT(a) from mysql.tbl").unwrap(),
1874                Some(4_usize),
1875            );
1876        }
1877        #[test]
1878        fn should_handle_LOCAL_INFILE_with_custom_handler() {
1879            let mut conn = Conn::new(get_opts()).unwrap();
1880            conn.query_drop("CREATE TEMPORARY TABLE mysql.tbl(a TEXT)")
1881                .unwrap();
1882            conn.set_local_infile_handler(Some(LocalInfileHandler::new(|_, stream| {
1883                let mut cell_data = vec![b'Z'; 65535];
1884                cell_data.push(b'\n');
1885                for _ in 0..1536 {
1886                    stream.write_all(&cell_data)?;
1887                }
1888                Ok(())
1889            })));
1890            match conn.query_drop("LOAD DATA LOCAL INFILE 'file_name' INTO TABLE mysql.tbl") {
1891                Ok(_) => {}
1892                Err(ref err) if format!("{}", err).find("not allowed").is_some() => {
1893                    return;
1894                }
1895                Err(err) => panic!("ERROR {}", err),
1896            }
1897            let count = conn
1898                .query_iter("SELECT * FROM mysql.tbl")
1899                .unwrap()
1900                .map(|row| {
1901                    assert_eq!(from_row::<(Vec<u8>,)>(row.unwrap()).0.len(), 65535);
1902                    1
1903                })
1904                .sum::<usize>();
1905            assert_eq!(count, 1536);
1906        }
1907
1908        #[test]
1909        fn should_reset_connection() {
1910            let mut conn = Conn::new(get_opts()).unwrap();
1911            conn.query_drop(
1912                "CREATE TEMPORARY TABLE `mysql`.`test` \
1913                 (`test` VARCHAR(255) NULL);",
1914            )
1915            .unwrap();
1916            conn.query_drop("INSERT INTO `mysql`.`test` (`test`) VALUES ('foo');")
1917                .unwrap();
1918            assert_eq!(conn.affected_rows(), 1);
1919            conn.reset().unwrap();
1920            assert_eq!(conn.affected_rows(), 0);
1921            conn.query_drop("SELECT * FROM `mysql`.`test`;")
1922                .unwrap_err();
1923        }
1924
1925        #[test]
1926        fn should_change_user() -> crate::Result<()> {
1927            /// Whether particular authentication plugin should be tested on the current database.
1928            type ShouldRunFn = fn(bool, (u16, u16, u16)) -> bool;
1929            /// Generates `CREATE USER` and `SET PASSWORD` statements
1930            type CreateUserFn = fn(bool, (u16, u16, u16), &str) -> Vec<String>;
1931
1932            #[allow(clippy::type_complexity)]
1933            const TEST_MATRIX: [(&str, ShouldRunFn, CreateUserFn); 4] = [
1934                (
1935                    "mysql_old_password",
1936                    |is_mariadb, version| is_mariadb || version < (5, 7, 0),
1937                    |is_mariadb, version, pass| {
1938                        if is_mariadb {
1939                            vec![
1940                                "CREATE USER '__mats'@'%' IDENTIFIED WITH mysql_old_password"
1941                                    .into(),
1942                                "SET old_passwords=1".into(),
1943                                format!("ALTER USER '__mats'@'%' IDENTIFIED BY '{pass}'"),
1944                                "SET old_passwords=0".into(),
1945                            ]
1946                        } else if matches!(version, (5, 6, _)) {
1947                            vec![
1948                                "CREATE USER '__mats'@'%' IDENTIFIED WITH mysql_old_password"
1949                                    .into(),
1950                                format!("SET PASSWORD FOR '__mats'@'%' = OLD_PASSWORD('{pass}')"),
1951                            ]
1952                        } else {
1953                            vec![
1954                                "CREATE USER '__mats'@'%'".into(),
1955                                format!("SET PASSWORD FOR '__mats'@'%' = PASSWORD('{pass}')"),
1956                            ]
1957                        }
1958                    },
1959                ),
1960                (
1961                    "mysql_native_password",
1962                    |is_mariadb, version| is_mariadb || version < (8, 4, 0),
1963                    |is_mariadb, version, pass| {
1964                        if is_mariadb {
1965                            vec![
1966                                format!("CREATE USER '__mats'@'%' IDENTIFIED WITH mysql_native_password AS PASSWORD('{pass}')")
1967                            ]
1968                        } else if version < (8, 0, 0) {
1969                            vec![
1970                                "CREATE USER '__mats'@'%' IDENTIFIED WITH mysql_native_password"
1971                                    .into(),
1972                                "SET old_passwords = 0".into(),
1973                                format!("SET PASSWORD FOR '__mats'@'%' = PASSWORD('{pass}')"),
1974                            ]
1975                        } else {
1976                            vec![
1977                                format!("CREATE USER '__mats'@'%' IDENTIFIED WITH mysql_native_password BY '{pass}'")
1978                            ]
1979                        }
1980                    },
1981                ),
1982                (
1983                    "caching_sha2_password",
1984                    |is_mariadb, version| !is_mariadb && version >= (5, 8, 0),
1985                    |_is_mariadb, _version, pass| {
1986                        vec![
1987                            format!("CREATE USER '__mats'@'%' IDENTIFIED WITH caching_sha2_password BY '{pass}'")
1988                        ]
1989                    },
1990                ),
1991                (
1992                    "client_ed25519",
1993                    |is_mariadb, version| is_mariadb && version >= (10, 4, 0),
1994                    |_is_mariadb, _version, pass| {
1995                        vec![
1996                            format!("CREATE USER '__mats'@'%' IDENTIFIED WITH ed25519 AS PASSWORD('{pass}')")
1997                        ]
1998                    },
1999                ),
2000            ];
2001
2002            fn random_pass() -> String {
2003                let mut rng = rand::thread_rng();
2004                let mut pass = [0u8; 10];
2005                pass.try_fill(&mut rng).unwrap();
2006                IntoIterator::into_iter(pass)
2007                    .map(|x| ((x % (123 - 97)) + 97) as char)
2008                    .collect()
2009            }
2010
2011            let mut conn = Conn::new(get_opts()).unwrap();
2012
2013            assert_eq!(
2014                conn.query_first::<Value, _>("SELECT @foo")
2015                    .unwrap()
2016                    .unwrap(),
2017                Value::NULL
2018            );
2019
2020            conn.query_drop("SET @foo = 'foo'").unwrap();
2021
2022            assert_eq!(
2023                conn.query_first::<String, _>("SELECT @foo")
2024                    .unwrap()
2025                    .unwrap(),
2026                "foo",
2027            );
2028
2029            conn.change_user(Default::default()).unwrap();
2030            assert_eq!(
2031                conn.query_first::<Value, _>("SELECT @foo")
2032                    .unwrap()
2033                    .unwrap(),
2034                Value::NULL
2035            );
2036
2037            for (plugin, should_run, create_statements) in TEST_MATRIX {
2038                dbg!(plugin);
2039                let is_mariadb = conn.0.mariadb_server_version.is_some();
2040                let version = conn.server_version();
2041
2042                if should_run(is_mariadb, version) {
2043                    let pass = random_pass();
2044
2045                    // (M)!50700 IF EXISTS: 5.7.0 (also on MariaDB) is minimum version that sees this clause
2046                    let statement =
2047                        "DROP USER /*!50700 IF EXISTS */ /*M!50700 IF EXISTS */ '__mats'";
2048                    // No IF EXISTS before 5.7 so the query may fail otherwise
2049                    _ = conn.query_drop(dbg!(statement));
2050
2051                    for statement in create_statements(is_mariadb, version, &pass) {
2052                        conn.query_drop(dbg!(statement)).unwrap();
2053                    }
2054
2055                    let mut conn2 = Conn::new(get_opts().secure_auth(false)).unwrap();
2056                    conn2
2057                        .change_user(
2058                            crate::ChangeUserOpts::default()
2059                                .with_db_name(None)
2060                                .with_user(Some("__mats".into()))
2061                                .with_pass(Some(pass)),
2062                        )
2063                        .unwrap();
2064
2065                    let (db, user) = conn2
2066                        .query_first::<(Option<String>, String), _>("SELECT DATABASE(), USER();")
2067                        .unwrap()
2068                        .unwrap();
2069                    assert_eq!(db, None);
2070                    assert!(user.starts_with("__mats"));
2071                }
2072            }
2073
2074            Ok(())
2075        }
2076
2077        #[test]
2078        fn prep_exec() {
2079            let mut conn = Conn::new(get_opts()).unwrap();
2080
2081            let stmt1 = conn.prep("SELECT :foo").unwrap();
2082            let stmt2 = conn.prep("SELECT :bar").unwrap();
2083            assert_eq!(
2084                conn.exec::<String, _, _>(&stmt1, params! { "foo" => "foo" })
2085                    .unwrap(),
2086                vec![String::from("foo")],
2087            );
2088            assert_eq!(
2089                conn.exec::<String, _, _>(&stmt2, params! { "bar" => "bar" })
2090                    .unwrap(),
2091                vec![String::from("bar")],
2092            );
2093        }
2094
2095        #[test]
2096        fn should_connect_via_socket_for_127_0_0_1() {
2097            let opts = OptsBuilder::from_opts(get_opts());
2098            let mut conn = Conn::new(opts).unwrap();
2099            if conn.is_insecure() {
2100                assert!(
2101                    conn.is_socket(),
2102                    "Did not reconnect via socket {:?}",
2103                    (
2104                        conn.0.opts.get_prefer_socket(),
2105                        conn.0.opts.addr_is_loopback(),
2106                        conn.can_improved().and_then(|opts| {
2107                            opts.map(|opts| {
2108                                let mut new = crate::conn::Conn(Box::new(ConnInner::empty(opts)));
2109                                new.connect_stream().and_then(|_| {
2110                                    new.connect()?;
2111                                    Ok(new)
2112                                })
2113                            })
2114                            .transpose()
2115                        }),
2116                    )
2117                );
2118            }
2119        }
2120
2121        #[test]
2122        fn should_connect_via_socket_localhost() {
2123            let opts = OptsBuilder::from_opts(get_opts()).ip_or_hostname(Some("localhost"));
2124            let mut conn = Conn::new(opts).unwrap();
2125            if conn.is_insecure() {
2126                assert!(
2127                    conn.is_socket(),
2128                    "Did not reconnect via socket {:?}",
2129                    (
2130                        conn.0.opts.get_prefer_socket(),
2131                        conn.0.opts.addr_is_loopback(),
2132                        conn.can_improved().and_then(|opts| {
2133                            opts.map(|opts| {
2134                                let mut new = crate::conn::Conn(Box::new(ConnInner::empty(opts)));
2135                                new.connect_stream().and_then(|_| {
2136                                    new.connect()?;
2137                                    Ok(new)
2138                                })
2139                            })
2140                            .transpose()
2141                        }),
2142                    )
2143                );
2144            }
2145        }
2146
2147        /// QueryResult::drop hangs on connectivity errors (see [blackbeam/rust-mysql-simple#306][1]).
2148        ///
2149        /// [1]: https://github.com/blackbeam/rust-mysql-simple/issues/306
2150        #[test]
2151        fn issue_306() {
2152            let (tx, rx) = channel::<()>();
2153            let handle = spawn(move || {
2154                let mut c1 = Conn::new(get_opts()).unwrap();
2155                let c1_id = c1.connection_id();
2156                let mut c2 = Conn::new(get_opts()).unwrap();
2157                let query_result = c1.query_iter("DO 1; SELECT SLEEP(1); DO 2;").unwrap();
2158                c2.query_drop(format!("KILL {c1_id}")).unwrap();
2159                drop(c2);
2160                drop(query_result);
2161                tx.send(()).unwrap();
2162            });
2163            std::thread::sleep(Duration::from_secs(2));
2164            assert!(rx.try_recv().is_ok());
2165            handle.join().unwrap();
2166        }
2167
2168        #[test]
2169        fn reset_does_work() {
2170            let mut c = Conn::new(get_opts()).unwrap();
2171            let cid = c.connection_id();
2172            c.query_drop("SET @foo = 'foo'").unwrap();
2173            assert_eq!(
2174                c.query_first::<String, _>("SELECT @foo").unwrap().unwrap(),
2175                "foo",
2176            );
2177            c.reset().unwrap();
2178            assert_eq!(cid, c.connection_id());
2179            assert_eq!(
2180                c.query_first::<Value, _>("SELECT @foo").unwrap().unwrap(),
2181                Value::NULL
2182            );
2183        }
2184
2185        #[test]
2186        fn should_drop_multi_result_set() {
2187            let opts = OptsBuilder::from_opts(get_opts()).db_name(Some("mysql"));
2188            let mut conn = Conn::new(opts).unwrap();
2189            conn.query_drop("CREATE TEMPORARY TABLE TEST_TABLE ( name varchar(255) )")
2190                .unwrap();
2191            conn.exec_drop("SELECT * FROM TEST_TABLE", ()).unwrap();
2192            conn.query_drop(
2193                r"
2194                INSERT INTO TEST_TABLE (name) VALUES ('one');
2195                INSERT INTO TEST_TABLE (name) VALUES ('two');
2196                INSERT INTO TEST_TABLE (name) VALUES ('three');",
2197            )
2198            .unwrap();
2199            conn.exec_drop("SELECT * FROM TEST_TABLE", ()).unwrap();
2200
2201            let mut query_result = conn
2202                .query_iter(
2203                    r"
2204                SELECT * FROM TEST_TABLE;
2205                INSERT INTO TEST_TABLE (name) VALUES ('one');
2206                DO 0;",
2207                )
2208                .unwrap();
2209
2210            while let Some(result) = query_result.iter() {
2211                result.affected_rows();
2212            }
2213        }
2214
2215        #[test]
2216        fn should_handle_multi_result_set() {
2217            let opts = OptsBuilder::from_opts(get_opts())
2218                .prefer_socket(false)
2219                .db_name(Some("mysql"));
2220            let mut conn = Conn::new(opts).unwrap();
2221            conn.query_drop("DROP PROCEDURE IF EXISTS multi").unwrap();
2222            conn.query_drop(
2223                r#"CREATE PROCEDURE multi() BEGIN
2224                        SELECT 1 UNION ALL SELECT 2;
2225                        DO 1;
2226                        SELECT 3 UNION ALL SELECT 4;
2227                        DO 1;
2228                        DO 1;
2229                        SELECT REPEAT('A', 17000000);
2230                        SELECT REPEAT('A', 17000000);
2231                    END"#,
2232            )
2233            .unwrap();
2234            {
2235                let mut query_result = conn.query_iter("CALL multi()").unwrap();
2236                let result_set = query_result
2237                    .by_ref()
2238                    .map(|row| row.unwrap().unwrap().pop().unwrap())
2239                    .collect::<Vec<crate::Value>>();
2240                assert_eq!(result_set, vec![Bytes(b"1".to_vec()), Bytes(b"2".to_vec())]);
2241                let result_set = query_result
2242                    .by_ref()
2243                    .map(|row| row.unwrap().unwrap().pop().unwrap())
2244                    .collect::<Vec<crate::Value>>();
2245                assert_eq!(result_set, vec![Bytes(b"3".to_vec()), Bytes(b"4".to_vec())]);
2246            }
2247            let mut result = conn.query_iter("SELECT 1; SELECT 2; SELECT 3;").unwrap();
2248            let mut i = 0;
2249            while let Some(result_set) = result.iter() {
2250                i += 1;
2251                for row in result_set {
2252                    match i {
2253                        1 => assert_eq!(row.unwrap().unwrap(), vec![Bytes(b"1".to_vec())]),
2254                        2 => assert_eq!(row.unwrap().unwrap(), vec![Bytes(b"2".to_vec())]),
2255                        3 => assert_eq!(row.unwrap().unwrap(), vec![Bytes(b"3".to_vec())]),
2256                        _ => unreachable!(),
2257                    }
2258                }
2259            }
2260            assert_eq!(i, 3);
2261        }
2262
2263        #[test]
2264        fn issue_273() {
2265            let opts = OptsBuilder::from_opts(get_opts()).prefer_socket(false);
2266            let mut conn = Conn::new(opts).unwrap();
2267
2268            "DROP FUNCTION IF EXISTS f1".run(&mut conn).unwrap();
2269            r"CREATE DEFINER=`root`@`localhost` FUNCTION `f1`(p_arg INT, p_arg2 INT) RETURNS int
2270            DETERMINISTIC
2271            BEGIN
2272                RETURN p_arg + p_arg2;
2273            END"
2274            .run(&mut conn)
2275            .unwrap();
2276
2277            "SELECT f1(?, ?)"
2278                .with((100u8, 100u8))
2279                .run(&mut conn)
2280                .unwrap();
2281        }
2282
2283        #[test]
2284        fn issue_285() {
2285            let (tx, rx) = sync_channel::<()>(0);
2286
2287            let handle = std::thread::spawn(move || {
2288                let mut conn = Conn::new(get_opts()).unwrap();
2289                const INVALID_SQL: &str = r#"
2290                CREATE TEMPORARY TABLE IF NOT EXISTS `user_details` (
2291                    `user_id` int(11) NOT NULL AUTO_INCREMENT,
2292                    `username` varchar(255) DEFAULT NULL,
2293                    `first_name` varchar(50) DEFAULT NULL,
2294                    `last_name` varchar(50) DEFAULT NULL,
2295                    PRIMARY KEY (`user_id`)
2296                );
2297
2298                INSERT INTO `user_details` (`user_id`, `username`, `first_name`, `last_name`)
2299                VALUES (1, 'rogers63', 'david')
2300                "#;
2301
2302                conn.query_iter(INVALID_SQL).unwrap();
2303                tx.send(()).unwrap();
2304            });
2305
2306            match rx.recv_timeout(Duration::from_secs(100_000)) {
2307                Ok(_) => handle.join().unwrap(),
2308                Err(_) => panic!("test failed"),
2309            }
2310        }
2311
2312        #[test]
2313        fn should_work_with_named_params() {
2314            let mut conn = Conn::new(get_opts()).unwrap();
2315            {
2316                let stmt = conn.prep("SELECT :a, :b, :a, :c").unwrap();
2317                let result = conn
2318                    .exec_first(&stmt, params! {"a" => 1, "b" => 2, "c" => 3})
2319                    .unwrap()
2320                    .unwrap();
2321                assert_eq!((1_u8, 2_u8, 1_u8, 3_u8), result);
2322            }
2323
2324            let result = conn
2325                .exec_first(
2326                    "SELECT :a, :b, :a + :b, :c",
2327                    params! {
2328                        "a" => 1,
2329                        "b" => 2,
2330                        "c" => 3,
2331                    },
2332                )
2333                .unwrap()
2334                .unwrap();
2335            assert_eq!((1_u8, 2_u8, 3_u8, 3_u8), result);
2336        }
2337
2338        #[test]
2339        fn should_return_error_on_missing_named_parameter() {
2340            let mut conn = Conn::new(get_opts()).unwrap();
2341            let stmt = conn.prep("SELECT :a, :b, :a, :c, :d").unwrap();
2342            let result =
2343                conn.exec_first::<crate::Row, _, _>(&stmt, params! {"a" => 1, "b" => 2, "c" => 3,});
2344            match result {
2345                Err(DriverError(MissingNamedParameter(ref x))) if x == "d" => (),
2346                _ => panic!("MissingNamedParameter error expected"),
2347            }
2348        }
2349
2350        #[test]
2351        fn should_return_error_on_named_params_for_positional_statement() {
2352            let mut conn = Conn::new(get_opts()).unwrap();
2353            let stmt = conn.prep("SELECT ?, ?, ?, ?, ?").unwrap();
2354            let result = conn.exec_drop(&stmt, params! {"a" => 1, "b" => 2, "c" => 3,});
2355            match result {
2356                Err(DriverError(NamedParamsForPositionalQuery)) => (),
2357                _ => panic!("NamedParamsForPositionalQuery error expected"),
2358            }
2359        }
2360
2361        #[test]
2362        fn should_handle_tcp_connect_timeout() {
2363            use crate::error::{DriverError::ConnectTimeout, Error::DriverError};
2364
2365            let opts = OptsBuilder::from_opts(get_opts())
2366                .prefer_socket(false)
2367                .tcp_connect_timeout(Some(::std::time::Duration::from_millis(1000)));
2368            assert!(Conn::new(opts).unwrap().ping().is_ok());
2369
2370            let opts = OptsBuilder::from_opts(get_opts())
2371                .prefer_socket(false)
2372                .tcp_connect_timeout(Some(::std::time::Duration::from_millis(1000)))
2373                .ip_or_hostname(Some("192.168.255.255"));
2374            match Conn::new(opts).unwrap_err() {
2375                DriverError(ConnectTimeout) => {}
2376                err => panic!("Unexpected error: {}", err),
2377            }
2378        }
2379
2380        #[test]
2381        fn should_set_additional_capabilities() {
2382            use crate::consts::CapabilityFlags;
2383
2384            let opts = OptsBuilder::from_opts(get_opts())
2385                .additional_capabilities(CapabilityFlags::CLIENT_FOUND_ROWS);
2386
2387            let mut conn = Conn::new(opts).unwrap();
2388            conn.query_drop("CREATE TEMPORARY TABLE mysql.tbl (a INT, b TEXT)")
2389                .unwrap();
2390            conn.query_drop("INSERT INTO mysql.tbl (a, b) VALUES (1, 'foo')")
2391                .unwrap();
2392            let result = conn
2393                .query_iter("UPDATE mysql.tbl SET b = 'foo' WHERE a = 1")
2394                .unwrap();
2395            assert_eq!(result.affected_rows(), 1);
2396        }
2397
2398        #[test]
2399        fn should_bind_before_connect() {
2400            let port = 28000 + (rand::random::<u16>() % 2000);
2401            let opts = OptsBuilder::from_opts(get_opts())
2402                .prefer_socket(false)
2403                .ip_or_hostname(Some("localhost"))
2404                .bind_address(Some(([127, 0, 0, 1], port)));
2405            let conn = Conn::new(opts).unwrap();
2406            let debug_format: String = format!("{:?}", conn);
2407            let expected_1 = format!("addr: V4(127.0.0.1:{})", port);
2408            let expected_2 = format!("addr: 127.0.0.1:{}", port);
2409            assert!(
2410                debug_format.contains(&expected_1) || debug_format.contains(&expected_2),
2411                "debug_format: {}",
2412                debug_format
2413            );
2414        }
2415
2416        #[test]
2417        fn should_bind_before_connect_with_timeout() {
2418            let port = 30000 + (rand::random::<u16>() % 2000);
2419            let opts = OptsBuilder::from_opts(get_opts())
2420                .prefer_socket(false)
2421                .ip_or_hostname(Some("localhost"))
2422                .bind_address(Some(([127, 0, 0, 1], port)))
2423                .tcp_connect_timeout(Some(::std::time::Duration::from_millis(1000)));
2424            let mut conn = Conn::new(opts).unwrap();
2425            assert!(conn.ping().is_ok());
2426            let debug_format: String = format!("{:?}", conn);
2427            let expected_1 = format!("addr: V4(127.0.0.1:{})", port);
2428            let expected_2 = format!("addr: 127.0.0.1:{}", port);
2429            assert!(
2430                debug_format.contains(&expected_1) || debug_format.contains(&expected_2),
2431                "debug_format: {}",
2432                debug_format
2433            );
2434        }
2435
2436        #[test]
2437        fn should_not_cache_statements_if_stmt_cache_size_is_zero() {
2438            let opts = OptsBuilder::from_opts(get_opts()).stmt_cache_size(0);
2439            let mut conn = Conn::new(opts).unwrap();
2440
2441            let stmt1 = conn.prep("DO 1").unwrap();
2442            let stmt2 = conn.prep("DO 2").unwrap();
2443            let stmt3 = conn.prep("DO 3").unwrap();
2444
2445            conn.close(stmt1).unwrap();
2446            conn.close(stmt2).unwrap();
2447            conn.close(stmt3).unwrap();
2448
2449            let status: (Value, u8) = conn
2450                .query_first("SHOW SESSION STATUS LIKE 'Com_stmt_close';")
2451                .unwrap()
2452                .unwrap();
2453            assert_eq!(status.1, 3);
2454        }
2455
2456        #[test]
2457        fn should_hold_stmt_cache_size_bounds() {
2458            let opts = OptsBuilder::from_opts(get_opts()).stmt_cache_size(3);
2459            let mut conn = Conn::new(opts).unwrap();
2460
2461            conn.prep("DO 1").unwrap();
2462            conn.prep("DO 2").unwrap();
2463            conn.prep("DO 3").unwrap();
2464            conn.prep("DO 1").unwrap();
2465            conn.prep("DO 4").unwrap();
2466            conn.prep("DO 3").unwrap();
2467            conn.prep("DO 5").unwrap();
2468            conn.prep("DO 6").unwrap();
2469
2470            let status: (String, usize) = conn
2471                .query_first("SHOW SESSION STATUS LIKE 'Com_stmt_close'")
2472                .unwrap()
2473                .unwrap();
2474
2475            assert_eq!(status.1, 3);
2476
2477            let mut order = conn
2478                .0
2479                .stmt_cache
2480                .iter()
2481                .map(|(_, entry)| &**entry.query.0.as_ref())
2482                .collect::<Vec<&[u8]>>();
2483            order.sort();
2484            assert_eq!(order, &[b"DO 3", b"DO 5", b"DO 6"]);
2485        }
2486
2487        #[test]
2488        fn should_handle_json_columns() {
2489            use crate::{Deserialized, Serialized};
2490            use serde::{Deserialize, Serialize};
2491            use serde_json::Value as Json;
2492            use std::str::FromStr;
2493
2494            #[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
2495            pub struct DecTest {
2496                foo: String,
2497                quux: (u64, String),
2498            }
2499
2500            let decodable = DecTest {
2501                foo: "bar".into(),
2502                quux: (42, "hello".into()),
2503            };
2504
2505            let mut conn = Conn::new(get_opts()).unwrap();
2506            if conn
2507                .query_drop("CREATE TEMPORARY TABLE mysql.tbl(a VARCHAR(32), b JSON)")
2508                .is_err()
2509            {
2510                conn.query_drop("CREATE TEMPORARY TABLE mysql.tbl(a VARCHAR(32), b TEXT)")
2511                    .unwrap();
2512            }
2513            conn.exec_drop(
2514                r#"INSERT INTO mysql.tbl VALUES ('hello', ?)"#,
2515                (Serialized(&decodable),),
2516            )
2517            .unwrap();
2518
2519            let (a, b): (String, Json) = conn
2520                .query_first("SELECT a, b FROM mysql.tbl")
2521                .unwrap()
2522                .unwrap();
2523            assert_eq!(
2524                (a, b),
2525                (
2526                    "hello".into(),
2527                    Json::from_str(r#"{"foo": "bar", "quux": [42, "hello"]}"#).unwrap()
2528                )
2529            );
2530
2531            let row = conn
2532                .exec_first("SELECT a, b FROM mysql.tbl WHERE a = ?", ("hello",))
2533                .unwrap()
2534                .unwrap();
2535            let (a, Deserialized(b)) = from_row(row);
2536            assert_eq!((a, b), (String::from("hello"), decodable));
2537        }
2538
2539        #[test]
2540        fn should_set_connect_attrs() {
2541            let opts = OptsBuilder::from_opts(
2542                get_opts().connect_attrs::<String, String>(Some(Default::default())),
2543            );
2544            let mut conn = Conn::new(opts).unwrap();
2545
2546            let support_connect_attrs = match (conn.0.server_version, conn.0.mariadb_server_version)
2547            {
2548                (Some(ref version), _) if *version >= (5, 6, 0) => true,
2549                (_, Some(ref version)) if *version >= (10, 0, 0) => true,
2550                _ => false,
2551            };
2552
2553            if support_connect_attrs {
2554                // MySQL >= 5.6 or MariaDB >= 10.0
2555
2556                if get_system_variable::<String>(&mut conn, "performance_schema") != "ON" {
2557                    panic!("The system variable `performance_schema` is off. Restart the MySQL server with `--performance_schema=on` to pass the test.");
2558                }
2559                let attrs_size: i32 =
2560                    get_system_variable(&mut conn, "performance_schema_session_connect_attrs_size");
2561                if (0..=128).contains(&attrs_size) {
2562                    panic!("The system variable `performance_schema_session_connect_attrs_size` is {}. Restart the MySQL server with `--performance_schema_session_connect_attrs_size=-1` to pass the test.", attrs_size);
2563                }
2564
2565                fn assert_connect_attrs(conn: &mut Conn, expected_values: &[(&str, &str)]) {
2566                    let mut actual_values = HashMap::new();
2567                    for row in conn.query_iter("SELECT attr_name, attr_value FROM performance_schema.session_account_connect_attrs WHERE processlist_id = connection_id()").unwrap() {
2568                        let (name, value) = from_row::<(String, String)>(row.unwrap());
2569                        actual_values.insert(name, value);
2570                    }
2571
2572                    for (name, value) in expected_values {
2573                        assert_eq!(actual_values.get(*name), Some(&value.to_string()));
2574                    }
2575                }
2576
2577                let pid = process::id().to_string();
2578                let prog_name = std::env::args_os()
2579                    .next()
2580                    .unwrap()
2581                    .to_string_lossy()
2582                    .into_owned();
2583                let mut expected_values = vec![
2584                    ("_client_name", "rust-mysql-simple"),
2585                    ("_client_version", env!("CARGO_PKG_VERSION")),
2586                    ("_os", env!("CARGO_CFG_TARGET_OS")),
2587                    ("_pid", &pid),
2588                    ("_platform", env!("CARGO_CFG_TARGET_ARCH")),
2589                    ("program_name", &prog_name),
2590                ];
2591
2592                // No connect attributes are added.
2593                assert_connect_attrs(&mut conn, &expected_values);
2594
2595                // Connect attributes are added.
2596                let opts = OptsBuilder::from_opts(get_opts());
2597                let mut connect_attrs = HashMap::with_capacity(3);
2598                connect_attrs.insert("foo", "foo val");
2599                connect_attrs.insert("bar", "bar val");
2600                connect_attrs.insert("program_name", "my program name");
2601                let mut conn = Conn::new(opts.connect_attrs(Some(connect_attrs))).unwrap();
2602                expected_values.pop(); // remove program_name at the last
2603                expected_values.push(("foo", "foo val"));
2604                expected_values.push(("bar", "bar val"));
2605                expected_values.push(("program_name", "my program name"));
2606                assert_connect_attrs(&mut conn, &expected_values);
2607            }
2608        }
2609
2610        // This test verifies that the metadata is correct with or without metadata caching, and that protocol
2611        // is not broken afterwards and data is read correctly. It doesn't test that the metadata is really cached
2612        // (if that is possible) and not received twice.
2613        #[test]
2614        fn test_metadata_caching() {
2615            use crate::consts::ColumnType;
2616            let mut conn = Conn::new(get_opts()).unwrap();
2617            if conn.0.mariadb_server_version.is_none() {
2618                return;
2619            }
2620
2621            conn.query_drop(
2622                r"CREATE TEMPORARY TABLE t_metadata_caching (
2623                    id INT NOT NULL PRIMARY KEY AUTO_INCREMENT,
2624                    val VARCHAR(32) NOT NULL)",
2625            )
2626            .unwrap();
2627
2628            // Populating table with some data to verify that data still fetched correctly with cached metadata use
2629            let insert_stmt = conn
2630                .prep("INSERT INTO t_metadata_caching (val) VALUES (?)")
2631                .unwrap();
2632            let _ = conn.exec_drop(&insert_stmt, ("AAA",));
2633            let _ = conn.exec_drop(&insert_stmt, ("BB",));
2634            let mut ps = conn.prep("SELECT id, val FROM t_metadata_caching").unwrap();
2635
2636            let mut columns_from_prep = ps.columns();
2637            let mut metadata_from_prep: Vec<(String, ColumnType)> = columns_from_prep
2638                .iter()
2639                .map(|column| (column.name_str().to_string(), column.column_type()))
2640                .collect();
2641
2642            let mut query_result = conn.exec_iter(&ps, ()).unwrap();
2643            let mut columns_from_exec1 = query_result.columns();
2644            let mut metadata_from_exec1: Vec<(String, ColumnType)> = columns_from_exec1
2645                .as_ref()
2646                .iter()
2647                .map(|column| (column.name_str().to_string(), column.column_type()))
2648                .collect();
2649
2650            // Comparing and verifying metadata.
2651            assert_eq!(metadata_from_prep, metadata_from_exec1);
2652            assert_eq!(metadata_from_prep.len(), 2);
2653            assert_eq!(metadata_from_prep[0].0, "id");
2654            assert_eq!(metadata_from_prep[0].1, ColumnType::MYSQL_TYPE_LONG);
2655            assert_eq!(metadata_from_prep[1].0, "val");
2656            assert_eq!(metadata_from_prep[1].1, ColumnType::MYSQL_TYPE_VAR_STRING);
2657
2658            let fetched_rows: Vec<(i32, String)> = query_result
2659                .map(|row_result| crate::from_row(row_result.unwrap()))
2660                .collect();
2661
2662            let expected_rows = [(1, "AAA".to_string()), (2, "BB".to_string())];
2663            assert_eq!(fetched_rows.len(), expected_rows.len());
2664
2665            for (fetched, expected) in fetched_rows.iter().zip(expected_rows.iter()) {
2666                assert_eq!(fetched, expected);
2667            }
2668
2669            // Doing the same for exec_first. Technically it's internally the same as exec_iter,
2670            // but the test isn't supposed to know that and to test it
2671            ps = conn
2672                .prep("SELECT val FROM t_metadata_caching WHERE id = ?")
2673                .unwrap();
2674            columns_from_prep = ps.columns();
2675            metadata_from_prep = columns_from_prep
2676                .iter()
2677                .map(|column| (column.name_str().to_string(), column.column_type()))
2678                .collect();
2679            let single_row: Option<String> = conn.exec_first(&ps, (1,)).unwrap();
2680            if let Some(val) = single_row {
2681                assert_eq!(metadata_from_prep.len(), 1);
2682                assert_eq!(metadata_from_prep[0].0, "val");
2683                assert_eq!(metadata_from_prep[0].1, ColumnType::MYSQL_TYPE_VAR_STRING);
2684
2685                assert_eq!(val, "AAA".to_string());
2686            }
2687            // Testing the case when metadata is changed after execution
2688            ps = conn.prep("SELECT ?").unwrap();
2689
2690            columns_from_prep = ps.columns();
2691            metadata_from_prep = columns_from_prep
2692                .iter()
2693                .map(|column| (column.name_str().to_string(), column.column_type()))
2694                .collect();
2695
2696            // First query — server sends metadata because the type has changed
2697            query_result = conn.exec_iter(&ps, (12,)).unwrap();
2698            columns_from_exec1 = query_result.columns();
2699            metadata_from_exec1 = columns_from_exec1
2700                .as_ref()
2701                .iter()
2702                .map(|column| (column.name_str().to_string(), column.column_type()))
2703                .collect();
2704            let fetched_rows: Vec<i32> = query_result
2705                .map(|row_result| crate::from_row(row_result.unwrap()))
2706                .collect();
2707
2708            // Second query — server skips metadata packets
2709            query_result = conn.exec_iter(&ps, (42,)).unwrap();
2710            let columns_from_exec2 = query_result.columns();
2711            let metadata_from_exec2 = columns_from_exec2
2712                .as_ref()
2713                .iter()
2714                .map(|column| (column.name_str().to_string(), column.column_type()))
2715                .collect::<Vec<_>>();
2716            let fetched_rows2: Vec<i32> = query_result
2717                .map(|row_result| crate::from_row(row_result.unwrap()))
2718                .collect();
2719
2720            // Third query — server sends metadata because the type has changed
2721            query_result = conn.exec_iter(&ps, ("foo",)).unwrap();
2722            let columns_from_exec3 = query_result.columns();
2723            let metadata_from_exec3 = columns_from_exec3
2724                .as_ref()
2725                .iter()
2726                .map(|column| (column.name_str().to_string(), column.column_type()))
2727                .collect::<Vec<_>>();
2728            let fetched_rows3: Vec<String> = query_result
2729                .map(|row_result| crate::from_row(row_result.unwrap()))
2730                .collect();
2731
2732            // Comparing and verifying metadata.
2733            assert_eq!(metadata_from_exec1.len(), 1);
2734            assert_eq!(metadata_from_exec2.len(), 1);
2735            assert_eq!(metadata_from_exec3.len(), 1);
2736            assert_eq!(metadata_from_prep.len(), 1);
2737            assert_eq!(metadata_from_prep[0].0, "?");
2738            assert!(
2739                metadata_from_prep[0].1 == ColumnType::MYSQL_TYPE_NULL
2740                    || metadata_from_prep[0].1 == ColumnType::MYSQL_TYPE_VAR_STRING,
2741                "Expected MYSQL_TYPE_NULL(MariaDB) or MYSQL_TYPE_VAR_STRING(MySQL), got {:?}",
2742                metadata_from_prep[0].1
2743            );
2744            assert_eq!(metadata_from_exec1[0].0, "?");
2745            assert_eq!(metadata_from_exec1[0].1, ColumnType::MYSQL_TYPE_LONGLONG);
2746            assert_eq!(metadata_from_exec2[0].0, "?");
2747            assert_eq!(metadata_from_exec2[0].1, ColumnType::MYSQL_TYPE_LONGLONG);
2748            assert_eq!(metadata_from_exec3[0].0, "?");
2749            assert_eq!(metadata_from_exec3[0].1, ColumnType::MYSQL_TYPE_VAR_STRING);
2750
2751            assert_eq!(fetched_rows[0], 12);
2752            assert_eq!(fetched_rows2[0], 42);
2753            assert_eq!(fetched_rows3[0], "foo".to_owned());
2754        }
2755
2756        #[test]
2757        #[cfg(feature = "binlog")]
2758        fn should_read_binlog() -> crate::Result<()> {
2759            use std::{
2760                collections::HashMap, sync::mpsc::sync_channel, thread::spawn, time::Duration,
2761            };
2762
2763            fn gen_dummy_data() -> crate::Result<()> {
2764                let mut conn = Conn::new(get_opts())?;
2765
2766                "CREATE TABLE IF NOT EXISTS customers (customer_id int not null)".run(&mut conn)?;
2767
2768                for i in 0_u8..100 {
2769                    "INSERT INTO customers(customer_id) VALUES (?)"
2770                        .with((i,))
2771                        .run(&mut conn)?;
2772                }
2773
2774                "DROP TABLE customers".run(&mut conn)?;
2775
2776                Ok(())
2777            }
2778
2779            fn get_conn() -> crate::Result<(Conn, Vec<u8>, u64)> {
2780                let mut conn = Conn::new(get_opts())?;
2781
2782                if let Ok(Some(gtid_mode)) =
2783                    "SELECT @@GLOBAL.GTID_MODE".first::<String, _>(&mut conn)
2784                {
2785                    if !gtid_mode.starts_with("ON") {
2786                        panic!(
2787                            "GTID_MODE is disabled \
2788                                (enable using --gtid_mode=ON --enforce_gtid_consistency=ON)"
2789                        );
2790                    }
2791                }
2792
2793                let row: crate::Row = "SHOW BINARY LOGS".first(&mut conn)?.unwrap();
2794                let filename = row.get(0).unwrap();
2795                let position = row.get(1).unwrap();
2796
2797                gen_dummy_data().unwrap();
2798                Ok((conn, filename, position))
2799            }
2800
2801            // iterate using COM_BINLOG_DUMP
2802            let (conn, filename, pos) = get_conn().unwrap();
2803            let is_mariadb = conn.0.mariadb_server_version.is_some();
2804
2805            let binlog_stream = conn
2806                .get_binlog_stream(BinlogRequest::new(12).with_filename(filename).with_pos(pos))
2807                .unwrap();
2808
2809            let mut events_num = 0;
2810            let (tx, rx) = sync_channel(0);
2811            spawn(move || {
2812                for event in binlog_stream {
2813                    tx.send(event).unwrap();
2814                }
2815            });
2816            let mut tmes = HashMap::new();
2817            while let Ok(event) = rx.recv_timeout(Duration::from_secs(1)) {
2818                let event = event.unwrap();
2819                events_num += 1;
2820
2821                // assert that event type is known
2822                event.header().event_type().unwrap();
2823
2824                // iterate over rows of an event
2825                match event.read_data()?.unwrap() {
2826                    EventData::TableMapEvent(tme) => {
2827                        tmes.insert(tme.table_id(), tme.into_owned());
2828                    }
2829                    EventData::RowsEvent(re) => {
2830                        for row in re.rows(&tmes[&re.table_id()]) {
2831                            row.unwrap();
2832                        }
2833                    }
2834                    _ => (),
2835                }
2836            }
2837            assert!(events_num > 0);
2838
2839            if !is_mariadb {
2840                // iterate using COM_BINLOG_DUMP_GTID
2841                let (conn, filename, pos) = get_conn().unwrap();
2842
2843                let binlog_stream = conn
2844                    .get_binlog_stream(
2845                        BinlogRequest::new(13)
2846                            .with_use_gtid(true)
2847                            .with_filename(filename)
2848                            .with_pos(pos),
2849                    )
2850                    .unwrap();
2851
2852                let mut events_num = 0;
2853                let (tx, rx) = sync_channel(0);
2854                spawn(move || {
2855                    for event in binlog_stream {
2856                        tx.send(event).unwrap();
2857                    }
2858                });
2859                let mut tmes = HashMap::new();
2860                while let Ok(event) = rx.recv_timeout(Duration::from_secs(1)) {
2861                    let event = event.unwrap();
2862                    events_num += 1;
2863
2864                    // assert that event type is known
2865                    event.header().event_type().unwrap();
2866
2867                    // iterate over rows of an event
2868                    match event.read_data()?.unwrap() {
2869                        EventData::TableMapEvent(tme) => {
2870                            tmes.insert(tme.table_id(), tme.into_owned());
2871                        }
2872                        EventData::RowsEvent(re) => {
2873                            for row in re.rows(&tmes[&re.table_id()]) {
2874                                row.unwrap();
2875                            }
2876                        }
2877                        _ => (),
2878                    }
2879                }
2880                assert!(events_num > 0);
2881            }
2882
2883            // iterate using COM_BINLOG_DUMP with BINLOG_DUMP_NON_BLOCK flag
2884            let (conn, filename, pos) = get_conn().unwrap();
2885
2886            let binlog_stream = conn
2887                .get_binlog_stream(
2888                    BinlogRequest::new(14)
2889                        .with_filename(filename)
2890                        .with_pos(pos)
2891                        .with_flags(crate::BinlogDumpFlags::BINLOG_DUMP_NON_BLOCK),
2892                )
2893                .unwrap();
2894
2895            events_num = 0;
2896            for event in binlog_stream {
2897                let event = event.unwrap();
2898                events_num += 1;
2899                event.header().event_type().unwrap();
2900                event.read_data()?;
2901            }
2902            assert!(events_num > 0);
2903
2904            Ok(())
2905        }
2906    }
2907
2908    #[cfg(feature = "nightly")]
2909    mod bench {
2910        use test;
2911
2912        use crate::{params, prelude::*, test_misc::get_opts, Conn, Value::NULL};
2913
2914        #[bench]
2915        fn simple_exec(bencher: &mut test::Bencher) {
2916            let mut conn = Conn::new(get_opts()).unwrap();
2917            bencher.iter(|| {
2918                let _ = conn.query_drop("DO 1");
2919            })
2920        }
2921
2922        #[bench]
2923        fn prepared_exec(bencher: &mut test::Bencher) {
2924            let mut conn = Conn::new(get_opts()).unwrap();
2925            let stmt = conn.prep("DO 1").unwrap();
2926            bencher.iter(|| {
2927                let _ = conn.exec_drop(&stmt, ()).unwrap();
2928            })
2929        }
2930
2931        #[bench]
2932        fn prepare_and_exec(bencher: &mut test::Bencher) {
2933            let mut conn = Conn::new(get_opts()).unwrap();
2934            bencher.iter(|| {
2935                let stmt = conn.prep("SELECT ?").unwrap();
2936                let _ = conn.exec_drop(&stmt, (0,)).unwrap();
2937            })
2938        }
2939
2940        #[bench]
2941        fn simple_query_row(bencher: &mut test::Bencher) {
2942            let mut conn = Conn::new(get_opts()).unwrap();
2943            bencher.iter(|| {
2944                let _ = conn.query_drop("SELECT 1").unwrap();
2945            })
2946        }
2947
2948        #[bench]
2949        fn simple_prepared_query_row(bencher: &mut test::Bencher) {
2950            let mut conn = Conn::new(get_opts()).unwrap();
2951            let stmt = conn.prep("SELECT 1").unwrap();
2952            bencher.iter(|| {
2953                let _ = conn.exec_drop(&stmt, ()).unwrap();
2954            })
2955        }
2956
2957        #[bench]
2958        fn simple_prepared_query_row_with_param(bencher: &mut test::Bencher) {
2959            let mut conn = Conn::new(get_opts()).unwrap();
2960            let stmt = conn.prep("SELECT ?").unwrap();
2961            bencher.iter(|| {
2962                let _ = conn.exec_drop(&stmt, (0,)).unwrap();
2963            })
2964        }
2965
2966        #[bench]
2967        fn simple_prepared_query_row_with_named_param(bencher: &mut test::Bencher) {
2968            let mut conn = Conn::new(get_opts()).unwrap();
2969            let stmt = conn.prep("SELECT :a").unwrap();
2970            bencher.iter(|| {
2971                let _ = conn.exec_drop(&stmt, params! {"a" => 0}).unwrap();
2972            })
2973        }
2974
2975        #[bench]
2976        fn simple_prepared_query_row_with_5_params(bencher: &mut test::Bencher) {
2977            let mut conn = Conn::new(get_opts()).unwrap();
2978            let stmt = conn.prep("SELECT ?, ?, ?, ?, ?").unwrap();
2979            let params = (42i8, b"123456".to_vec(), 1.618f64, NULL, 1i8);
2980            bencher.iter(|| {
2981                let _ = conn.exec_drop(&stmt, &params).unwrap();
2982            })
2983        }
2984
2985        #[bench]
2986        fn simple_prepared_query_row_with_5_named_params(bencher: &mut test::Bencher) {
2987            let mut conn = Conn::new(get_opts()).unwrap();
2988            let stmt = conn
2989                .prep("SELECT :one, :two, :three, :four, :five")
2990                .unwrap();
2991            bencher.iter(|| {
2992                let _ = conn.exec_drop(
2993                    &stmt,
2994                    params! {
2995                        "one" => 42i8,
2996                        "two" => b"123456",
2997                        "three" => 1.618f64,
2998                        "four" => NULL,
2999                        "five" => 1i8,
3000                    },
3001                );
3002            })
3003        }
3004
3005        #[bench]
3006        fn select_large_string(bencher: &mut test::Bencher) {
3007            let mut conn = Conn::new(get_opts()).unwrap();
3008            bencher.iter(|| {
3009                let _ = conn.query_drop("SELECT REPEAT('A', 10000)").unwrap();
3010            })
3011        }
3012
3013        #[bench]
3014        fn select_prepared_large_string(bencher: &mut test::Bencher) {
3015            let mut conn = Conn::new(get_opts()).unwrap();
3016            let stmt = conn.prep("SELECT REPEAT('A', 10000)").unwrap();
3017            bencher.iter(|| {
3018                let _ = conn.exec_drop(&stmt, ()).unwrap();
3019            })
3020        }
3021
3022        #[bench]
3023        fn many_small_rows(bencher: &mut test::Bencher) {
3024            let mut conn = Conn::new(get_opts()).unwrap();
3025            conn.query_drop("CREATE TEMPORARY TABLE mysql.x (id INT)")
3026                .unwrap();
3027            for _ in 0..512 {
3028                conn.query_drop("INSERT INTO mysql.x VALUES (256)").unwrap();
3029            }
3030            let stmt = conn.prep("SELECT * FROM mysql.x").unwrap();
3031            bencher.iter(|| {
3032                let _ = conn.exec_drop(&stmt, ()).unwrap();
3033            });
3034        }
3035    }
3036}