1use bytes::{Buf, BufMut};
10#[cfg(feature = "binlog")]
11use mysql_common::packets::binlog_request::BinlogRequest;
12use mysql_common::{
13 constants::UTF8MB4_GENERAL_CI,
14 crypto,
15 io::{ParseBuf, ReadMysqlExt},
16 named_params::ParsedNamedParams,
17 packets::{
18 AuthPlugin, AuthSwitchRequest, Column, ComChangeUser, ComChangeUserMoreData, ComStmtClose,
19 ComStmtExecuteRequestBuilder, ComStmtSendLongData, CommonOkPacket, ErrPacket,
20 HandshakePacket, HandshakeResponse, OkPacket, OkPacketDeserializer, OkPacketKind,
21 OldAuthSwitchRequest, OldEofPacket, ResultSetTerminator, SessionStateInfo,
22 },
23 proto::{codec::Compression, sync_framed::MySyncFramed, MySerialize},
24};
25
26use mysql_common::{
27 constants::{DEFAULT_MAX_ALLOWED_PACKET, UTF8_GENERAL_CI},
28 packets::SslRequest,
29};
30
31use std::{
32 borrow::{Borrow, Cow},
33 collections::HashMap,
34 convert::TryFrom,
35 io::{self, Write as _},
36 mem,
37 ops::{Deref, DerefMut},
38 process,
39 sync::Arc,
40};
41
42#[cfg(unix)]
43use std::os::unix::io::{AsRawFd, RawFd};
44
45use crate::{
46 buffer_pool::{get_buffer, Buffer},
47 conn::{
48 local_infile::LocalInfile,
49 pool::{Pool, PooledConn},
50 query_result::{Binary, ResultSetMeta, Text},
51 stmt::{InnerStmt, Statement},
52 stmt_cache::StmtCache,
53 transaction::{AccessMode, TxOpts},
54 },
55 consts::{CapabilityFlags, Command, MariadbCapabilities, StatusFlags, MAX_PAYLOAD_LEN},
56 from_value, from_value_opt,
57 io::Stream,
58 prelude::*,
59 ChangeUserOpts,
60 DriverError::{
61 CleartextPluginDisabled, MismatchedStmtParams, NamedParamsForPositionalQuery,
62 OldMysqlPasswordDisabled, Protocol41NotSet, ReadOnlyTransNotSupported, SetupError,
63 UnexpectedPacket, UnknownAuthPlugin, UnsupportedProtocol,
64 },
65 Error::{self, DriverError, MySqlError},
66 LocalInfileHandler, Opts, OptsBuilder, Params, QueryResult, Result, Transaction,
67 Value::{self, Bytes, NULL},
68};
69
70use crate::DriverError::TlsNotSupported;
71use crate::SslOpts;
72
73#[cfg(feature = "binlog")]
74use self::binlog_stream::BinlogStream;
75
76#[cfg(feature = "binlog")]
77pub mod binlog_stream;
78pub mod local_infile;
79pub mod opts;
80pub mod pool;
81pub mod query;
82pub mod query_result;
83pub mod queryable;
84pub mod stmt;
85mod stmt_cache;
86pub mod transaction;
87
/// Mutable handle to a connection in any of the forms this crate works with.
///
/// Every variant dereferences to [`Conn`] (see the `Deref`/`DerefMut` impls), so APIs
/// can accept a borrowed connection, a borrowed transaction, an owned connection or a
/// pooled connection uniformly.
#[derive(Debug)]
pub enum ConnMut<'c, 't, 'tc> {
    /// Mutable borrow of a plain connection.
    Mut(&'c mut Conn),
    /// Mutable borrow of a transaction; derefs to the connection it wraps.
    TxMut(&'t mut Transaction<'tc>),
    /// Connection owned by this value.
    Owned(Conn),
    /// Pooled connection owned by this value.
    Pooled(PooledConn),
}
96
97impl From<Conn> for ConnMut<'static, 'static, 'static> {
98 fn from(conn: Conn) -> Self {
99 ConnMut::Owned(conn)
100 }
101}
102
103impl From<PooledConn> for ConnMut<'static, 'static, 'static> {
104 fn from(conn: PooledConn) -> Self {
105 ConnMut::Pooled(conn)
106 }
107}
108
109impl<'a> From<&'a mut Conn> for ConnMut<'a, 'static, 'static> {
110 fn from(conn: &'a mut Conn) -> Self {
111 ConnMut::Mut(conn)
112 }
113}
114
115impl<'a> From<&'a mut PooledConn> for ConnMut<'a, 'static, 'static> {
116 fn from(conn: &'a mut PooledConn) -> Self {
117 ConnMut::Mut(conn.as_mut())
118 }
119}
120
121impl<'t, 'tc> From<&'t mut Transaction<'tc>> for ConnMut<'static, 't, 'tc> {
122 fn from(tx: &'t mut Transaction<'tc>) -> Self {
123 ConnMut::TxMut(tx)
124 }
125}
126
127impl TryFrom<&Pool> for ConnMut<'static, 'static, 'static> {
128 type Error = Error;
129
130 fn try_from(pool: &Pool) -> Result<Self> {
131 pool.get_conn().map(From::from)
132 }
133}
134
135impl Deref for ConnMut<'_, '_, '_> {
136 type Target = Conn;
137
138 fn deref(&self) -> &Conn {
139 match self {
140 ConnMut::Mut(conn) => conn,
141 ConnMut::TxMut(tx) => &tx.conn,
142 ConnMut::Owned(conn) => conn,
143 ConnMut::Pooled(conn) => conn.as_ref(),
144 }
145 }
146}
147
148impl DerefMut for ConnMut<'_, '_, '_> {
149 fn deref_mut(&mut self) -> &mut Conn {
150 match self {
151 ConnMut::Mut(conn) => conn,
152 ConnMut::TxMut(tx) => &mut tx.conn,
153 ConnMut::Owned(ref mut conn) => conn,
154 ConnMut::Pooled(ref mut conn) => conn.as_mut(),
155 }
156 }
157}
158
/// What `Conn::handle_result_set` learned from the start of a server response.
pub(crate) enum ResultSetInfo {
    /// The server answered with an OK packet (no rows); also produced after a
    /// LOCAL INFILE upload completes.
    Empty(OkPacket<'static>),
    /// A result set whose column definitions were read from the wire.
    NonEmptyWithMeta(Vec<Column>),
    /// A result set whose column definitions were omitted by the server. This is only
    /// produced when `MARIADB_CLIENT_CACHE_METADATA` was negotiated; callers must use
    /// cached metadata instead.
    NonEmptySkipMeta,
}
165
166impl ResultSetInfo {
167 pub(crate) fn into_query_meta(self) -> ResultSetMeta {
168 match self {
169 ResultSetInfo::Empty(ok_packet) => ResultSetMeta::Empty(ok_packet),
170 ResultSetInfo::NonEmptyWithMeta(columns) => {
171 ResultSetMeta::NonEmptyWithMeta(columns.into())
172 }
173 ResultSetInfo::NonEmptySkipMeta => {
174 ResultSetMeta::NonEmptyWithMeta(Default::default())
176 }
177 }
178 }
179
180 pub(crate) fn into_statement_meta(self, conn: &Conn, stmt: &Statement) -> ResultSetMeta {
181 match self {
182 ResultSetInfo::Empty(ok_packet) => ResultSetMeta::Empty(ok_packet),
183 ResultSetInfo::NonEmptyWithMeta(columns) => {
184 stmt.update_columns_metadata(columns);
185 ResultSetMeta::NonEmptyWithMeta(stmt.columns())
186 }
187 ResultSetInfo::NonEmptySkipMeta => {
188 assert!(
189 conn.has_mariadb_capability(MariadbCapabilities::MARIADB_CLIENT_CACHE_METADATA),
190 "metadata skipped but no MARIADB_CLIENT_CACHE_METADATA capability negotiated"
191 );
192 ResultSetMeta::NonEmptyWithMeta(stmt.columns())
193 }
194 }
195 }
196}
197
/// Mutable state behind a [`Conn`].
#[derive(Debug)]
struct ConnInner {
    /// Connection options; `change_user` updates them in place.
    opts: Opts,
    /// Framed packet stream. `None` until `connect_stream` succeeds, and also while
    /// `switch_to_ssl` swaps the transport.
    stream: Option<MySyncFramed<Stream>>,
    /// Client-side cache of prepared statements, looked up by query text.
    stmt_cache: StmtCache,

    /// Server version parsed from the handshake, if any.
    server_version: Option<(u16, u16, u16)>,
    /// MariaDB server version parsed from the handshake, if any.
    mariadb_server_version: Option<(u16, u16, u16)>,

    /// Last OK packet received; cleared by `handle_err`.
    ok_packet: Option<OkPacket<'static>>,
    /// Server capabilities ANDed with the client flags during the handshake.
    capability_flags: CapabilityFlags,
    /// Negotiated MariaDB extended capabilities (empty unless negotiated).
    mariadb_ext_capabilities: MariadbCapabilities,
    /// Connection id reported in the handshake.
    connection_id: u32,
    /// Status flags from the handshake or the most recent OK packet.
    status_flags: StatusFlags,
    /// Default collation id reported in the handshake.
    character_set: u8,
    /// First byte (command code) of the last command written; set back to 0 after
    /// a connection reset or change-user.
    last_command: u8,
    /// When set, `connect` returns immediately without a handshake.
    connected: bool,
    /// Set when the last result set had columns; cleared on error.
    has_results: bool,
    /// Connection-level LOCAL INFILE handler; takes precedence over the one in `opts`.
    local_infile_handler: Option<LocalInfileHandler>,

    /// Authentication plugin currently in use.
    auth_plugin: AuthPlugin<'static>,
    /// Scramble used to compute authentication data (from the handshake or the
    /// latest auth-switch request).
    nonce: Vec<u8>,

    /// Initialised from the pool options' `reset_connection()`; presumably consulted
    /// by the pool when the connection is returned — confirm in `pool`.
    pub(crate) reset_upon_return: bool,
}
227
228impl ConnInner {
229 fn empty(opts: Opts) -> Self {
230 ConnInner {
231 stmt_cache: StmtCache::new(opts.get_stmt_cache_size()),
232 stream: None,
233 capability_flags: CapabilityFlags::empty(),
234 status_flags: StatusFlags::empty(),
235 connection_id: 0u32,
236 character_set: 0u8,
237 ok_packet: None,
238 last_command: 0u8,
239 connected: false,
240 has_results: false,
241 server_version: None,
242 mariadb_server_version: None,
243 mariadb_ext_capabilities: MariadbCapabilities::empty(),
244 local_infile_handler: None,
245 auth_plugin: AuthPlugin::MysqlNativePassword,
246 nonce: Vec::new(),
247 reset_upon_return: opts.get_pool_opts().reset_connection(),
248
249 opts,
250 }
251 }
252}
253
/// Mysql connection.
///
/// All state lives in a boxed [`ConnInner`], so moving a `Conn` moves only a pointer.
#[derive(Debug)]
pub struct Conn(Box<ConnInner>);
257
258impl Conn {
259 const fn has_capability(&self, flag: CapabilityFlags) -> bool {
261 self.0.capability_flags.contains(flag)
262 }
263
264 const fn has_mariadb_capability(&self, flag: MariadbCapabilities) -> bool {
266 self.0.mariadb_ext_capabilities.contains(flag)
267 }
268
269 pub fn server_version(&self) -> (u16, u16, u16) {
271 self.0
272 .server_version
273 .or(self.0.mariadb_server_version)
274 .unwrap()
275 }
276
277 pub fn connection_id(&self) -> u32 {
279 self.0.connection_id
280 }
281
282 pub fn affected_rows(&self) -> u64 {
284 self.0
285 .ok_packet
286 .as_ref()
287 .map(OkPacket::affected_rows)
288 .unwrap_or_default()
289 }
290
291 pub fn last_insert_id(&self) -> u64 {
295 self.0
296 .ok_packet
297 .as_ref()
298 .and_then(OkPacket::last_insert_id)
299 .unwrap_or_default()
300 }
301
302 pub fn warnings(&self) -> u16 {
304 self.0
305 .ok_packet
306 .as_ref()
307 .map(OkPacket::warnings)
308 .unwrap_or_default()
309 }
310
311 pub fn info_ref(&self) -> &[u8] {
317 self.0
318 .ok_packet
319 .as_ref()
320 .and_then(OkPacket::info_ref)
321 .unwrap_or_default()
322 }
323
324 pub fn info_str(&self) -> Cow<str> {
330 self.0
331 .ok_packet
332 .as_ref()
333 .and_then(OkPacket::info_str)
334 .unwrap_or_default()
335 }
336
337 pub fn session_state_changes(&self) -> io::Result<Vec<SessionStateInfo<'_>>> {
338 self.0
339 .ok_packet
340 .as_ref()
341 .map(|ok| ok.session_state_info())
342 .transpose()
343 .map(Option::unwrap_or_default)
344 }
345
346 fn stream_ref(&self) -> &MySyncFramed<Stream> {
347 self.0.stream.as_ref().expect("incomplete connection")
348 }
349
350 fn stream_mut(&mut self) -> &mut MySyncFramed<Stream> {
351 self.0.stream.as_mut().expect("incomplete connection")
352 }
353
354 fn is_insecure(&self) -> bool {
355 self.stream_ref().get_ref().is_insecure()
356 }
357
358 fn is_socket(&self) -> bool {
359 self.stream_ref().get_ref().is_socket()
360 }
361
362 #[allow(unused_assignments)]
364 fn can_improved(&mut self) -> Result<Option<Opts>> {
365 if self.0.opts.get_prefer_socket() && self.0.opts.addr_is_loopback() {
366 let mut socket = None;
367 #[cfg(test)]
368 {
369 socket = self.0.opts.0.injected_socket.clone();
370 }
371 if socket.is_none() {
372 socket = self.get_system_var("socket")?.map(from_value::<String>);
373 }
374 if let Some(socket) = socket {
375 if self.0.opts.get_socket().is_none() {
376 let socket_opts = OptsBuilder::from_opts(self.0.opts.clone());
377 if !socket.is_empty() {
378 return Ok(Some(socket_opts.socket(Some(socket)).into()));
379 }
380 }
381 }
382 }
383 Ok(None)
384 }
385
386 pub fn new<T, E>(opts: T) -> Result<Conn>
388 where
389 Opts: TryFrom<T, Error = E>,
390 crate::Error: From<E>,
391 {
392 let opts = Opts::try_from(opts)?;
393 let mut conn = Conn(Box::new(ConnInner::empty(opts)));
394 conn.connect_stream()?;
395 conn.connect()?;
396 let mut conn = {
397 if let Some(new_opts) = conn.can_improved()? {
398 let mut improved_conn = Conn(Box::new(ConnInner::empty(new_opts)));
399 improved_conn
400 .connect_stream()
401 .and_then(|_| {
402 improved_conn.connect()?;
403 Ok(improved_conn)
404 })
405 .unwrap_or(conn)
406 } else {
407 conn
408 }
409 };
410 for cmd in conn.0.opts.get_init() {
411 conn.query_drop(cmd)?;
412 }
413 Ok(conn)
414 }
415
416 fn exec_com_reset_connection(&mut self) -> Result<()> {
417 self.write_command(Command::COM_RESET_CONNECTION, &[])?;
418 let packet = self.read_packet()?;
419 self.handle_ok::<CommonOkPacket>(&packet)?;
420 self.0.last_command = 0;
421 self.0.stmt_cache.clear();
422 Ok(())
423 }
424
425 fn exec_com_change_user(&mut self, opts: ChangeUserOpts) -> Result<()> {
426 opts.update_opts(&mut self.0.opts);
427 let com_change_user = ComChangeUser::new()
428 .with_user(self.0.opts.get_user().map(|x| x.as_bytes()))
429 .with_database(self.0.opts.get_db_name().map(|x| x.as_bytes()))
430 .with_auth_plugin_data(
431 self.0
432 .auth_plugin
433 .gen_data(self.0.opts.get_pass(), &self.0.nonce)
434 .as_deref(),
435 )
436 .with_more_data(Some(
437 ComChangeUserMoreData::new(if self.server_version() >= (5, 5, 3) {
438 UTF8MB4_GENERAL_CI
439 } else {
440 UTF8_GENERAL_CI
441 })
442 .with_auth_plugin(Some(self.0.auth_plugin.clone()))
443 .with_connect_attributes(self.0.opts.get_connect_attrs().cloned()),
444 ))
445 .into_owned();
446 self.write_command_raw(&com_change_user)?;
447 self.0.last_command = 0;
448 self.0.stmt_cache.clear();
449 self.continue_auth(false)
450 }
451
452 pub fn reset(&mut self) -> Result<()> {
466 let reset_result = match (self.0.server_version, self.0.mariadb_server_version) {
467 (Some(ref version), _) if *version > (5, 7, 3) => self.exec_com_reset_connection(),
468 (_, Some(ref version)) if *version >= (10, 2, 7) => self.exec_com_reset_connection(),
469 _ => return self.exec_com_change_user(ChangeUserOpts::DEFAULT),
470 };
471
472 match reset_result {
473 Ok(_) => (),
474 Err(crate::Error::MySqlError(_)) => {
475 self.exec_com_change_user(ChangeUserOpts::DEFAULT)?;
477 }
478 Err(e) => return Err(e),
479 }
480
481 for cmd in self.0.opts.get_init() {
482 self.query_drop(cmd)?;
483 }
484
485 Ok(())
486 }
487
488 pub fn change_user(&mut self, opts: ChangeUserOpts) -> Result<()> {
505 self.exec_com_change_user(opts)
506 }
507
508 fn switch_to_ssl(&mut self, ssl_opts: SslOpts) -> Result<()> {
509 let stream = self.0.stream.take().expect("incomplete conn");
510 let (in_buf, out_buf, codec, stream) = stream.destruct();
511 let stream = stream.make_secure(self.0.opts.get_host(), ssl_opts)?;
512 let stream = MySyncFramed::construct(in_buf, out_buf, codec, stream);
513 self.0.stream = Some(stream);
514 Ok(())
515 }
516
    /// Opens the transport for this connection and stores it in `self.0.stream`,
    /// wrapped in a `MySyncFramed` packet framer.
    ///
    /// If a socket path is configured (`get_socket`), `Stream::connect_socket` is used
    /// with the read/write timeouts. Otherwise `Stream::connect_tcp` connects to the
    /// configured host and port and passes the keepalive, nodelay, connect-timeout and
    /// bind-address options. The keepalive probe settings are passed only on
    /// Linux/macOS, and the TCP user timeout only on Linux, as the `#[cfg]` gates show.
    ///
    /// # Errors
    /// Propagates any error from opening the socket or TCP stream.
    fn connect_stream(&mut self) -> Result<()> {
        let opts = &self.0.opts;
        let read_timeout = opts.get_read_timeout().cloned();
        let write_timeout = opts.get_write_timeout().cloned();
        let tcp_keepalive_time = opts.get_tcp_keepalive_time_ms();
        #[cfg(any(target_os = "linux", target_os = "macos",))]
        let tcp_keepalive_probe_interval_secs = opts.get_tcp_keepalive_probe_interval_secs();
        #[cfg(any(target_os = "linux", target_os = "macos",))]
        let tcp_keepalive_probe_count = opts.get_tcp_keepalive_probe_count();
        #[cfg(target_os = "linux")]
        let tcp_user_timeout = opts.get_tcp_user_timeout_ms();
        let tcp_nodelay = opts.get_tcp_nodelay();
        let tcp_connect_timeout = opts.get_tcp_connect_timeout();
        let bind_address = opts.bind_address().cloned();
        let stream = if let Some(socket) = opts.get_socket() {
            Stream::connect_socket(socket, read_timeout, write_timeout)?
        } else {
            let port = opts.get_tcp_port();
            // Render the host for the TCP connector: domains as-is, IP literals via
            // `to_string`.
            let ip_or_hostname = match opts.get_host() {
                url::Host::Domain(domain) => domain,
                url::Host::Ipv4(ip) => ip.to_string(),
                url::Host::Ipv6(ip) => ip.to_string(),
            };
            // The argument list varies per target OS; the `#[cfg]`s must stay in sync with
            // the `Stream::connect_tcp` signature.
            Stream::connect_tcp(
                &ip_or_hostname,
                port,
                read_timeout,
                write_timeout,
                tcp_keepalive_time,
                #[cfg(any(target_os = "linux", target_os = "macos",))]
                tcp_keepalive_probe_interval_secs,
                #[cfg(any(target_os = "linux", target_os = "macos",))]
                tcp_keepalive_probe_count,
                #[cfg(target_os = "linux")]
                tcp_user_timeout,
                tcp_nodelay,
                tcp_connect_timeout,
                bind_address,
            )?
        };
        self.0.stream = Some(MySyncFramed::new(stream));
        Ok(())
    }
560
561 fn raw_read_packet(&mut self, buffer: &mut Vec<u8>) -> Result<()> {
562 if !self.stream_mut().next_packet(buffer)? {
563 Err(Error::server_disconnected())
564 } else {
565 Ok(())
566 }
567 }
568
569 fn read_packet(&mut self) -> Result<Buffer> {
570 loop {
571 let mut buffer = get_buffer();
572 match self.raw_read_packet(buffer.as_mut()) {
573 Ok(()) if buffer.first() == Some(&0xff) => {
574 match ParseBuf(&buffer).parse(self.0.capability_flags)? {
575 ErrPacket::Error(server_error) => {
576 self.handle_err();
577 return Err(MySqlError(From::from(server_error)));
578 }
579 ErrPacket::Progress(_progress_report) => {
580 continue;
582 }
583 }
584 }
585 Ok(()) => return Ok(buffer),
586 Err(e) => {
587 self.handle_err();
588 return Err(e);
589 }
590 }
591 }
592 }
593
594 fn drop_packet(&mut self) -> Result<()> {
595 self.read_packet().map(drop)
596 }
597
598 fn write_struct<T: MySerialize>(&mut self, s: &T) -> Result<()> {
599 let mut buf = get_buffer();
600 s.serialize(buf.as_mut());
601 self.write_packet(&mut &*buf)
602 }
603
604 fn write_packet<T: Buf>(&mut self, data: &mut T) -> Result<()> {
605 self.stream_mut().send(data)?;
606 Ok(())
607 }
608
609 fn handle_handshake(&mut self, hp: &HandshakePacket<'_>) {
610 self.0.capability_flags = hp.capabilities() & self.get_client_flags();
611 self.0.status_flags = hp.status_flags();
612 self.0.connection_id = hp.connection_id();
613 self.0.character_set = hp.default_collation();
614 self.0.server_version = hp.server_version_parsed();
615 self.0.mariadb_server_version = hp.maria_db_server_version_parsed();
616 if self.0.mariadb_server_version.is_some()
619 && !self
620 .0
621 .capability_flags
622 .contains(CapabilityFlags::CLIENT_LONG_PASSWORD)
623 {
624 self.0.mariadb_ext_capabilities =
625 hp.mariadb_ext_capabilities() & self.get_mariadb_client_flags();
626 }
627 }
628
629 fn handle_ok<'a, T: OkPacketKind>(
630 &mut self,
631 buffer: &'a Buffer,
632 ) -> crate::Result<OkPacket<'a>> {
633 let ok = ParseBuf(buffer)
634 .parse::<OkPacketDeserializer<T>>(self.0.capability_flags)?
635 .into_inner();
636 self.0.status_flags = ok.status_flags();
637 self.0.ok_packet = Some(ok.clone().into_owned());
638 Ok(ok)
639 }
640
641 fn handle_err(&mut self) {
642 self.0.status_flags = StatusFlags::empty();
643 self.0.has_results = false;
644 self.0.ok_packet = None;
645 }
646
647 fn more_results_exists(&self) -> bool {
648 self.0
649 .status_flags
650 .contains(StatusFlags::SERVER_MORE_RESULTS_EXISTS)
651 }
652
653 fn perform_auth_switch(&mut self, auth_switch_request: AuthSwitchRequest<'_>) -> Result<()> {
654 if matches!(
655 auth_switch_request.auth_plugin(),
656 AuthPlugin::MysqlOldPassword
657 ) && self.0.opts.get_secure_auth()
658 {
659 return Err(DriverError(OldMysqlPasswordDisabled));
660 }
661
662 if matches!(
663 auth_switch_request.auth_plugin(),
664 AuthPlugin::Other(Cow::Borrowed(b"mysql_clear_password"))
665 ) && !self.0.opts.get_enable_cleartext_plugin()
666 {
667 return Err(DriverError(CleartextPluginDisabled));
668 }
669
670 self.0.nonce = auth_switch_request.plugin_data().to_vec();
671 self.0.auth_plugin = auth_switch_request.auth_plugin().into_owned();
672 let plugin_data = match self.0.auth_plugin {
673 ref x @ AuthPlugin::MysqlOldPassword => {
674 if self.0.opts.get_secure_auth() {
675 return Err(DriverError(OldMysqlPasswordDisabled));
676 }
677 x.gen_data(self.0.opts.get_pass(), &self.0.nonce)
678 }
679 ref x @ AuthPlugin::MysqlNativePassword => {
680 x.gen_data(self.0.opts.get_pass(), &self.0.nonce)
681 }
682 ref x @ AuthPlugin::CachingSha2Password => {
683 x.gen_data(self.0.opts.get_pass(), &self.0.nonce)
684 }
685 ref x @ AuthPlugin::MysqlClearPassword => {
686 if !self.0.opts.get_enable_cleartext_plugin() {
687 return Err(DriverError(UnknownAuthPlugin(
688 "mysql_clear_password".into(),
689 )));
690 }
691
692 x.gen_data(self.0.opts.get_pass(), &self.0.nonce)
693 }
694 ref x @ AuthPlugin::Ed25519 => x.gen_data(self.0.opts.get_pass(), &self.0.nonce),
695 AuthPlugin::Other(_) => None,
696 };
697
698 if let Some(plugin_data) = plugin_data {
699 self.write_struct(&plugin_data.into_owned())?;
700 } else {
701 self.write_packet(&mut &[0_u8; 0][..])?;
702 }
703
704 self.continue_auth(true)
705 }
706
    /// Performs the connection handshake on a freshly opened stream.
    ///
    /// Steps, in order:
    /// 1. Read the server's handshake packet.
    /// 2. Reject protocol versions other than 10, and servers without
    ///    `CLIENT_PROTOCOL_41`.
    /// 3. Record the negotiated state.
    /// 4. Optionally upgrade to TLS.
    /// 5. Build the 20-byte nonce, pick the initial auth plugin, send the handshake
    ///    response and run the auth exchange.
    /// 6. Enable compression if `CLIENT_COMPRESS` was negotiated.
    ///
    /// # Errors
    /// `UnsupportedProtocol`, `Protocol41NotSet`, `TlsNotSupported` (TLS configured but
    /// not offered by the server), plus any I/O, parse or authentication error.
    fn do_handshake(&mut self) -> Result<()> {
        let payload = self.read_packet()?;
        let handshake = ParseBuf(&payload).parse::<HandshakePacket>(())?;

        if handshake.protocol_version() != 10u8 {
            return Err(DriverError(UnsupportedProtocol(
                handshake.protocol_version(),
            )));
        }

        if !handshake
            .capabilities()
            .contains(CapabilityFlags::CLIENT_PROTOCOL_41)
        {
            return Err(DriverError(Protocol41NotSet));
        }

        self.handle_handshake(&handshake);

        // TLS is only negotiated on a stream that is not already secure, and only if the
        // options ask for it.
        if self.is_insecure() {
            if let Some(ssl_opts) = self.0.opts.get_ssl_opts().cloned() {
                if !self.has_capability(CapabilityFlags::CLIENT_SSL) {
                    return Err(DriverError(TlsNotSupported));
                } else {
                    self.do_ssl_request()?;
                    self.switch_to_ssl(ssl_opts)?;
                }
            }
        }

        self.0.nonce = {
            // Concatenate both scramble parts, then pad/truncate to exactly 20 bytes.
            let mut nonce = Vec::from(handshake.scramble_1_ref());
            nonce.extend_from_slice(handshake.scramble_2_ref().unwrap_or(&[][..]));
            nonce.resize(20, 0);
            nonce
        };

        // Only caching_sha2_password is adopted from the handshake. Anything else starts
        // as mysql_native_password; the continue_* steps handle an auth-switch (0xfe)
        // request if the server wants a different plugin.
        self.0.auth_plugin = match handshake.auth_plugin() {
            Some(x @ AuthPlugin::CachingSha2Password) => x.into_owned(),
            _ => AuthPlugin::MysqlNativePassword,
        };

        self.write_handshake_response()?;
        self.continue_auth(false)?;

        // Compression is enabled only after authentication completes.
        if self.has_capability(CapabilityFlags::CLIENT_COMPRESS) {
            self.switch_to_compressed();
        }

        Ok(())
    }
764
765 fn switch_to_compressed(&mut self) {
766 self.stream_mut()
767 .codec_mut()
768 .compress(Compression::default());
769 }
770
771 fn get_client_flags(&self) -> CapabilityFlags {
772 let mut client_flags = CapabilityFlags::CLIENT_PROTOCOL_41
773 | CapabilityFlags::CLIENT_SECURE_CONNECTION
774 | CapabilityFlags::CLIENT_LONG_PASSWORD
775 | CapabilityFlags::CLIENT_TRANSACTIONS
776 | CapabilityFlags::CLIENT_LOCAL_FILES
777 | CapabilityFlags::CLIENT_MULTI_STATEMENTS
778 | CapabilityFlags::CLIENT_MULTI_RESULTS
779 | CapabilityFlags::CLIENT_PS_MULTI_RESULTS
780 | CapabilityFlags::CLIENT_PLUGIN_AUTH
781 | (self.0.capability_flags & CapabilityFlags::CLIENT_LONG_FLAG);
782 if self.0.opts.get_compress().is_some() {
783 client_flags.insert(CapabilityFlags::CLIENT_COMPRESS);
784 }
785 if self.0.opts.get_connect_attrs().is_some() {
786 client_flags.insert(CapabilityFlags::CLIENT_CONNECT_ATTRS);
787 }
788 if let Some(db_name) = self.0.opts.get_db_name() {
789 if !db_name.is_empty() {
790 client_flags.insert(CapabilityFlags::CLIENT_CONNECT_WITH_DB);
791 }
792 }
793 if self.is_insecure() && self.0.opts.get_ssl_opts().is_some() {
794 client_flags.insert(CapabilityFlags::CLIENT_SSL);
795 }
796 client_flags | self.0.opts.get_additional_capabilities()
797 }
798
799 fn get_mariadb_client_flags(&self) -> MariadbCapabilities {
800 MariadbCapabilities::MARIADB_CLIENT_CACHE_METADATA
801 }
802
803 fn connect_attrs(&self) -> Option<HashMap<String, String>> {
804 if let Some(attrs) = self.0.opts.get_connect_attrs() {
805 let program_name = match attrs.get("program_name") {
806 Some(program_name) => program_name.clone(),
807 None => {
808 let arg0 = std::env::args_os().next();
809 let arg0 = arg0.as_ref().map(|x| x.to_string_lossy());
810 arg0.unwrap_or_else(|| "".into()).into_owned()
811 }
812 };
813
814 let mut attrs_to_send = HashMap::new();
815
816 attrs_to_send.insert("_client_name".into(), "rust-mysql-simple".into());
817 attrs_to_send.insert("_client_version".into(), env!("CARGO_PKG_VERSION").into());
818 attrs_to_send.insert("_os".into(), env!("CARGO_CFG_TARGET_OS").into());
819 attrs_to_send.insert("_pid".into(), process::id().to_string());
820 attrs_to_send.insert("_platform".into(), env!("CARGO_CFG_TARGET_ARCH").into());
821 attrs_to_send.insert("program_name".into(), program_name);
822
823 for (name, value) in attrs.clone() {
824 attrs_to_send.insert(name, value);
825 }
826
827 Some(attrs_to_send)
828 } else {
829 None
830 }
831 }
832
833 fn do_ssl_request(&mut self) -> Result<()> {
834 let charset = if self.server_version() >= (5, 5, 3) {
835 UTF8MB4_GENERAL_CI
836 } else {
837 UTF8_GENERAL_CI
838 };
839
840 let ssl_request = SslRequest::new(
841 self.get_client_flags(),
842 DEFAULT_MAX_ALLOWED_PACKET as u32,
843 charset as u8,
844 );
845 self.write_struct(&ssl_request)
846 }
847
848 fn write_handshake_response(&mut self) -> Result<()> {
849 let auth_data = self
850 .0
851 .auth_plugin
852 .gen_data(self.0.opts.get_pass(), &self.0.nonce)
853 .map(|x| x.into_owned());
854
855 let handshake_response = HandshakeResponse::new(
856 auth_data.as_deref(),
857 self.0.server_version.unwrap_or((0, 0, 0)),
858 self.0.opts.get_user().map(str::as_bytes),
859 self.0.opts.get_db_name().map(str::as_bytes),
860 Some(self.0.auth_plugin.clone()),
861 self.0.capability_flags,
862 self.connect_attrs(),
863 self.0
864 .opts
865 .get_max_allowed_packet()
866 .unwrap_or(DEFAULT_MAX_ALLOWED_PACKET) as u32,
867 )
868 .with_mariadb_ext_capabilities(self.0.mariadb_ext_capabilities);
869
870 let mut buf = get_buffer();
871 handshake_response.serialize(buf.as_mut());
872 self.write_packet(&mut &*buf)
873 }
874
875 fn continue_auth(&mut self, auth_switched: bool) -> Result<()> {
876 match self.0.auth_plugin {
877 AuthPlugin::CachingSha2Password => {
878 self.continue_caching_sha2_password_auth(auth_switched)?;
879 Ok(())
880 }
881 AuthPlugin::MysqlNativePassword | AuthPlugin::MysqlOldPassword => {
882 self.continue_mysql_native_password_auth(auth_switched)?;
883 Ok(())
884 }
885 AuthPlugin::MysqlClearPassword => {
886 if !self.0.opts.get_enable_cleartext_plugin() {
887 return Err(DriverError(CleartextPluginDisabled));
888 }
889 self.continue_mysql_native_password_auth(auth_switched)?;
890 Ok(())
891 }
892 AuthPlugin::Ed25519 => {
893 self.continue_ed25519_auth(auth_switched)?;
894 Ok(())
895 }
896 AuthPlugin::Other(ref name) => {
897 let plugin_name = String::from_utf8_lossy(name).into();
898 Err(DriverError(UnknownAuthPlugin(plugin_name)))
899 }
900 }
901 }
902
903 fn continue_mysql_native_password_auth(&mut self, auth_switched: bool) -> Result<()> {
904 let payload = self.read_packet()?;
905
906 match payload[0] {
907 0x00 => self.handle_ok::<CommonOkPacket>(&payload).map(drop),
909 0xfe if !auth_switched => {
911 let auth_switch = if payload.len() > 1 {
912 ParseBuf(&payload).parse(())?
913 } else {
914 let _ = ParseBuf(&payload).parse::<OldAuthSwitchRequest>(())?;
915 AuthSwitchRequest::new("mysql_old_password".as_bytes(), &*self.0.nonce)
917 .into_owned()
918 };
919 self.perform_auth_switch(auth_switch)
920 }
921 _ => Err(DriverError(UnexpectedPacket)),
922 }
923 }
924
925 fn continue_caching_sha2_password_auth(&mut self, auth_switched: bool) -> Result<()> {
926 let payload = self.read_packet()?;
927
928 match payload[0] {
929 0x00 => {
930 Ok(())
932 }
933 0x01 => match payload[1] {
934 0x03 => {
935 let payload = self.read_packet()?;
936 self.handle_ok::<CommonOkPacket>(&payload).map(drop)
937 }
938 0x04 => {
939 if !self.is_insecure() || self.is_socket() {
940 let mut pass = self.0.opts.get_pass().map(Vec::from).unwrap_or_default();
941 pass.push(0);
942 self.write_packet(&mut pass.as_slice())?;
943 } else {
944 self.write_packet(&mut &[0x02][..])?;
945 let payload = self.read_packet()?;
946 let key = &payload[1..];
947 let mut pass = self.0.opts.get_pass().map(Vec::from).unwrap_or_default();
948 pass.push(0);
949 for (i, c) in pass.iter_mut().enumerate() {
950 *(c) ^= self.0.nonce[i % self.0.nonce.len()];
951 }
952 let encrypted_pass = crypto::encrypt(&pass, key);
953 self.write_packet(&mut encrypted_pass.as_slice())?;
954 }
955
956 let payload = self.read_packet()?;
957 self.handle_ok::<CommonOkPacket>(&payload).map(drop)
958 }
959 _ => Err(DriverError(UnexpectedPacket)),
960 },
961 0xfe if !auth_switched => {
962 let auth_switch_request = ParseBuf(&payload).parse(())?;
963 self.perform_auth_switch(auth_switch_request)
964 }
965 _ => Err(DriverError(UnexpectedPacket)),
966 }
967 }
968
969 fn continue_ed25519_auth(&mut self, auth_switched: bool) -> Result<()> {
970 let payload = self.read_packet()?;
971 match payload[0] {
972 0x00 => Ok(()),
974 0xfe if !auth_switched => {
975 let auth_switch_request = ParseBuf(&payload).parse(())?;
976 self.perform_auth_switch(auth_switch_request)
977 }
978 _ => Err(DriverError(UnexpectedPacket)),
979 }
980 }
981
982 fn reset_seq_id(&mut self) {
983 self.stream_mut().codec_mut().reset_seq_id();
984 }
985
986 fn sync_seq_id(&mut self) {
987 self.stream_mut().codec_mut().sync_seq_id();
988 }
989
990 fn write_command_raw<T: MySerialize>(&mut self, cmd: &T) -> Result<()> {
991 let mut buf = get_buffer();
992 cmd.serialize(buf.as_mut());
993 self.reset_seq_id();
994 debug_assert!(!buf.is_empty());
995 self.0.last_command = buf[0];
996 self.write_packet(&mut &*buf)
997 }
998
999 fn write_command(&mut self, cmd: Command, data: &[u8]) -> Result<()> {
1000 let mut buf = get_buffer();
1001 buf.as_mut().put_u8(cmd as u8);
1002 buf.as_mut().extend_from_slice(data);
1003
1004 self.reset_seq_id();
1005 self.0.last_command = buf[0];
1006 self.write_packet(&mut &*buf)
1007 }
1008
1009 fn send_long_data(&mut self, stmt_id: u32, params: &[Value]) -> Result<()> {
1010 for (i, value) in params.iter().enumerate() {
1011 if let Bytes(bytes) = value {
1012 let chunks = bytes.chunks(MAX_PAYLOAD_LEN - 6);
1013 let chunks = chunks.chain(if bytes.is_empty() {
1014 Some(&[][..])
1015 } else {
1016 None
1017 });
1018 for chunk in chunks {
1019 let cmd = ComStmtSendLongData::new(stmt_id, i as u16, Cow::Borrowed(chunk));
1020 self.write_command_raw(&cmd)?;
1021 }
1022 }
1023 }
1024
1025 Ok(())
1026 }
1027
1028 fn _execute(&mut self, stmt: &Statement, params: Params) -> Result<ResultSetInfo> {
1029 let exec_request = match ¶ms {
1030 Params::Empty => {
1031 if stmt.num_params() != 0 {
1032 return Err(DriverError(MismatchedStmtParams(stmt.num_params(), 0)));
1033 }
1034
1035 let (body, _) = ComStmtExecuteRequestBuilder::new(stmt.id()).build(&[]);
1036 body
1037 }
1038 Params::Positional(params) => {
1039 if stmt.num_params() != params.len() as u16 {
1040 return Err(DriverError(MismatchedStmtParams(
1041 stmt.num_params(),
1042 params.len(),
1043 )));
1044 }
1045
1046 let (body, as_long_data) =
1047 ComStmtExecuteRequestBuilder::new(stmt.id()).build(params);
1048
1049 if as_long_data {
1050 self.send_long_data(stmt.id(), params)?;
1051 }
1052
1053 body
1054 }
1055 Params::Named(_) => {
1056 if let Some(named_params) = stmt.named_params.as_ref() {
1057 return self._execute(stmt, params.into_positional(named_params)?);
1058 } else {
1059 return Err(DriverError(NamedParamsForPositionalQuery));
1060 }
1061 }
1062 };
1063 self.write_command_raw(&exec_request)?;
1064 self.handle_result_set()
1065 }
1066
1067 fn _start_transaction(&mut self, tx_opts: TxOpts) -> Result<()> {
1068 if let Some(i_level) = tx_opts.isolation_level() {
1069 self.query_drop(format!("SET TRANSACTION ISOLATION LEVEL {}", i_level))?;
1070 }
1071 if let Some(mode) = tx_opts.access_mode() {
1072 let supported = match (self.0.server_version, self.0.mariadb_server_version) {
1073 (Some(ref version), _) if *version >= (5, 6, 5) => true,
1074 (_, Some(ref version)) if *version >= (10, 0, 0) => true,
1075 _ => false,
1076 };
1077 if !supported {
1078 return Err(DriverError(ReadOnlyTransNotSupported));
1079 }
1080 match mode {
1081 AccessMode::ReadOnly => self.query_drop("SET TRANSACTION READ ONLY")?,
1082 AccessMode::ReadWrite => self.query_drop("SET TRANSACTION READ WRITE")?,
1083 }
1084 }
1085 if tx_opts.with_consistent_snapshot() {
1086 self.query_drop("START TRANSACTION WITH CONSISTENT SNAPSHOT")
1087 .unwrap();
1088 } else {
1089 self.query_drop("START TRANSACTION")?;
1090 };
1091 Ok(())
1092 }
1093
    /// Answers a LOCAL INFILE request for `file_name`.
    ///
    /// The handler set on the connection takes precedence over the one in the options.
    /// It is invoked with a `LocalInfile` writer over a fixed-size buffer. If neither
    /// handler is set, only `flush` is called on that writer. An empty packet is then
    /// written and the server's OK packet is returned.
    ///
    /// # Errors
    /// Propagates a poisoned handler lock, handler errors, and I/O or server errors.
    fn send_local_infile(&mut self, file_name: &[u8]) -> Result<OkPacket<'static>> {
        // Inner scope: `LocalInfile` borrows `self` mutably and must be dropped before
        // the connection is used directly again below.
        {
            let mut buffer = [0_u8; LocalInfile::BUFFER_SIZE];
            let maybe_handler = self
                .0
                .local_infile_handler
                .clone()
                .or_else(|| self.0.opts.get_local_infile_handler().cloned());
            let mut local_infile = LocalInfile::new(&mut buffer, self);
            if let Some(handler) = maybe_handler {
                let handler_fn = &mut *handler.0.lock()?;
                handler_fn(file_name, &mut local_infile)?;
            }
            local_infile.flush()?;
        }
        // Empty packet after the file data; presumably the end-of-data marker the server
        // waits for before sending its OK.
        self.write_packet(&mut &[][..])?;
        let payload = self.read_packet()?;
        let ok = self.handle_ok::<CommonOkPacket>(&payload)?;
        Ok(ok.into_owned())
    }
1117
1118 fn handle_result_set(&mut self) -> Result<ResultSetInfo> {
1119 if self.more_results_exists() {
1120 self.sync_seq_id();
1121 }
1122
1123 let pld = self.read_packet()?;
1124 match pld[0] {
1125 0x00 => {
1126 let ok = self.handle_ok::<CommonOkPacket>(&pld)?;
1127 Ok(ResultSetInfo::Empty(ok.into_owned()))
1128 }
1129 0xfb => match self.send_local_infile(&pld[1..]) {
1130 Ok(ok) => Ok(ResultSetInfo::Empty(ok)),
1131 Err(err) => Err(err),
1132 },
1133 _ => {
1134 let mut reader = &pld[..];
1135 let column_count = reader.read_lenenc_int()?;
1136
1137 let mut columns: Vec<Column> = Vec::new();
1138
1139 let output = if !(self
1141 .has_mariadb_capability(MariadbCapabilities::MARIADB_CLIENT_CACHE_METADATA)
1142 && reader.first().copied() == Some(0x00))
1143 {
1144 columns.reserve(column_count as usize);
1145 for _ in 0..column_count {
1146 let pld = self.read_packet()?;
1147 let column = ParseBuf(&pld).parse(())?;
1148 columns.push(column);
1149 }
1150
1151 ResultSetInfo::NonEmptyWithMeta(columns)
1152 } else {
1153 ResultSetInfo::NonEmptySkipMeta
1154 };
1155
1156 self.drop_packet()?;
1158 self.0.has_results = column_count > 0;
1159
1160 Ok(output)
1161 }
1162 }
1163 }
1164
1165 fn _query(&mut self, query: &str) -> Result<ResultSetMeta> {
1166 self.write_command(Command::COM_QUERY, query.as_bytes())?;
1167 let info = self.handle_result_set()?;
1168 let meta = info.into_query_meta();
1169 Ok(meta)
1170 }
1171
1172 pub fn ping(&mut self) -> Result<(), Error> {
1175 self.write_command(Command::COM_PING, &[])?;
1176 self.drop_packet()
1177 }
1178
1179 pub fn select_db(&mut self, schema: &str) -> Result<(), Error> {
1182 self.write_command(Command::COM_INIT_DB, schema.as_bytes())?;
1183 self.drop_packet()
1184 }
1185
1186 pub fn start_transaction(&mut self, tx_opts: TxOpts) -> Result<Transaction> {
1189 self._start_transaction(tx_opts)?;
1190 Ok(Transaction::new(self.into()))
1191 }
1192
1193 fn _true_prepare(&mut self, query: &[u8]) -> Result<InnerStmt> {
1194 self.write_command(Command::COM_STMT_PREPARE, query)?;
1195 let pld = self.read_packet()?;
1196 let mut stmt = ParseBuf(&pld).parse::<InnerStmt>(self.connection_id())?;
1197 if stmt.num_params() > 0 {
1198 let mut params: Vec<Column> = Vec::with_capacity(stmt.num_params() as usize);
1199 for _ in 0..stmt.num_params() {
1200 let pld = self.read_packet()?;
1201 params.push(ParseBuf(&pld).parse(())?);
1202 }
1203 stmt = stmt.with_params(Some(params));
1204 self.drop_packet()?;
1205 }
1206 if stmt.num_columns() > 0 {
1207 let mut columns: Vec<Column> = Vec::with_capacity(stmt.num_columns() as usize);
1208 for _ in 0..stmt.num_columns() {
1209 let pld = self.read_packet()?;
1210 columns.push(ParseBuf(&pld).parse(())?);
1211 }
1212 stmt = stmt.with_columns(Some(columns));
1213 self.drop_packet()?;
1214 }
1215 Ok(stmt)
1216 }
1217
1218 fn _prepare(&mut self, query: &[u8]) -> Result<Arc<InnerStmt>> {
1219 if let Some(entry) = self.0.stmt_cache.by_query(query) {
1220 return Ok(entry.stmt.clone());
1221 }
1222
1223 let inner_st = Arc::new(self._true_prepare(query)?);
1224
1225 if let Some(old_stmt) = self
1226 .0
1227 .stmt_cache
1228 .put(Arc::new(query.into()), inner_st.clone())
1229 {
1230 self.close(Statement::new(old_stmt, None))?;
1231 }
1232
1233 Ok(inner_st)
1234 }
1235
1236 fn connect(&mut self) -> Result<()> {
1237 if self.0.connected {
1238 return Ok(());
1239 }
1240 self.do_handshake()
1241 .and_then(|_| match self.0.opts.get_max_allowed_packet() {
1242 Some(x) => Ok(x),
1243 None => Ok(from_value_opt::<usize>(
1244 self.get_system_var("max_allowed_packet")?.unwrap_or(NULL),
1245 )
1246 .unwrap_or(0)),
1247 })
1248 .and_then(|max_allowed_packet| {
1249 if max_allowed_packet == 0 {
1250 Err(DriverError(SetupError))
1251 } else {
1252 self.stream_mut().codec_mut().max_allowed_packet = max_allowed_packet;
1253 self.0.connected = true;
1254 Ok(())
1255 }
1256 })
1257 }
1258
1259 fn get_system_var(&mut self, name: &str) -> Result<Option<Value>> {
1260 self.query_first(format!("SELECT @@{}", name))
1261 }
1262
1263 fn next_row_packet(&mut self) -> Result<Option<Buffer>> {
1264 if !self.0.has_results {
1265 return Ok(None);
1266 }
1267
1268 let pld = self.read_packet()?;
1269
1270 if self.has_capability(CapabilityFlags::CLIENT_DEPRECATE_EOF) {
1271 if pld[0] == 0xfe && pld.len() < MAX_PAYLOAD_LEN {
1272 self.0.has_results = false;
1273 self.handle_ok::<ResultSetTerminator>(&pld)?;
1274 return Ok(None);
1275 }
1276 } else if pld[0] == 0xfe && pld.len() < 8 {
1277 self.0.has_results = false;
1278 self.handle_ok::<OldEofPacket>(&pld)?;
1279 return Ok(None);
1280 }
1281
1282 Ok(Some(pld))
1283 }
1284
1285 fn has_stmt(&self, query: &[u8]) -> bool {
1286 self.0.stmt_cache.contains_query(query)
1287 }
1288
1289 pub fn set_local_infile_handler(&mut self, handler: Option<LocalInfileHandler>) {
1296 self.0.local_infile_handler = handler;
1297 }
1298
1299 pub fn no_backslash_escape(&self) -> bool {
1300 self.0
1301 .status_flags
1302 .contains(StatusFlags::SERVER_STATUS_NO_BACKSLASH_ESCAPES)
1303 }
1304
/// Registers this connection as a replica (`COM_REGISTER_SLAVE`) using
/// `server_id`, after setting `@master_binlog_checksum` for the session.
///
/// The reply packet is read and discarded; only read errors propagate.
#[cfg(feature = "binlog")]
fn register_as_slave(&mut self, server_id: u32) -> Result<()> {
    use mysql_common::packets::ComRegisterSlave;

    self.query_drop("SET @master_binlog_checksum='ALL'")?;

    let register_cmd = ComRegisterSlave::new(server_id);
    self.write_command_raw(&register_cmd)?;

    let _ = self.read_packet()?;
    Ok(())
}
1317
/// Registers as a replica with the request's server id, then sends the
/// command built from `request`.
#[cfg(feature = "binlog")]
fn request_binlog(&mut self, request: BinlogRequest<'_>) -> Result<()> {
    let server_id = request.server_id();
    self.register_as_slave(server_id)?;
    let cmd = request.as_cmd();
    self.write_command_raw(&cmd)
}
1324
/// Consumes the connection and turns it into a `BinlogStream`.
///
/// The connection is first registered as a replica and the binlog described
/// by `request` is requested.
#[cfg(feature = "binlog")]
#[cfg_attr(docsrs, doc(cfg(feature = "binlog")))]
pub fn get_binlog_stream(mut self, request: BinlogRequest<'_>) -> Result<BinlogStream> {
    self.request_binlog(request)
        .map(|_| BinlogStream::new(self))
}
1336
1337 fn cleanup_for_pool(&mut self) -> Result<()> {
1338 self.set_local_infile_handler(None);
1339 if self.0.reset_upon_return {
1340 self.reset()?;
1341 }
1342
1343 self.0.reset_upon_return = self.0.opts.get_pool_opts().reset_connection();
1344
1345 Ok(())
1346 }
1347}
1348
1349#[cfg(unix)]
1350impl AsRawFd for Conn {
1351 fn as_raw_fd(&self) -> RawFd {
1352 self.stream_ref().get_ref().as_raw_fd()
1353 }
1354}
1355
1356impl Queryable for Conn {
1357 fn query_iter<T: AsRef<str>>(&mut self, query: T) -> Result<QueryResult<'_, '_, '_, Text>> {
1358 let meta = self._query(query.as_ref())?;
1359 Ok(QueryResult::new(ConnMut::Mut(self), meta))
1360 }
1361
1362 fn prep<T: AsRef<str>>(&mut self, query: T) -> Result<Statement> {
1363 let query = query.as_ref();
1364 let parsed = ParsedNamedParams::parse(query.as_bytes())?;
1365 let named_params: Vec<Vec<u8>> =
1366 parsed.params().iter().map(|param| param.to_vec()).collect();
1367 let named_params = if named_params.is_empty() {
1368 None
1369 } else {
1370 Some(named_params)
1371 };
1372 self._prepare(parsed.borrow().query())
1373 .map(|inner| Statement::new(inner, named_params))
1374 }
1375
1376 fn close(&mut self, stmt: Statement) -> Result<()> {
1377 self.0.stmt_cache.remove(stmt.id());
1378 let cmd = ComStmtClose::new(stmt.id());
1379 self.write_command_raw(&cmd)
1380 }
1381
1382 fn exec_iter<S, P>(&mut self, stmt: S, params: P) -> Result<QueryResult<'_, '_, '_, Binary>>
1383 where
1384 S: AsStatement,
1385 P: Into<Params>,
1386 {
1387 let statement = stmt.as_statement(self)?;
1388 let info = self._execute(&statement, params.into())?;
1389 let meta = info.into_statement_meta(&*self, &statement);
1390 Ok(QueryResult::new(ConnMut::Mut(self), meta))
1391 }
1392}
1393
1394impl Drop for Conn {
1395 fn drop(&mut self) {
1396 let stmt_cache = mem::replace(&mut self.0.stmt_cache, StmtCache::new(0));
1397
1398 for (_, entry) in stmt_cache.into_iter() {
1399 let _ = self.close(Statement::new(entry.stmt, None));
1400 }
1401
1402 if self.0.stream.is_some() {
1403 let _ = self.write_command(Command::COM_QUIT, &[]);
1404 }
1405 }
1406}
1407
1408#[cfg(test)]
1409#[allow(non_snake_case)]
1410mod test {
1411 mod my_conn {
1412 use std::{
1413 collections::HashMap,
1414 io::Write,
1415 process,
1416 sync::mpsc::{channel, sync_channel},
1417 thread::spawn,
1418 time::Duration,
1419 };
1420
1421 #[cfg(feature = "binlog")]
1422 use mysql_common::{binlog::events::EventData, packets::binlog_request::BinlogRequest};
1423 use rand::Fill;
1424 #[cfg(feature = "time")]
1425 use time::PrimitiveDateTime;
1426
1427 use crate::{
1428 conn::ConnInner,
1429 from_row, from_value, params,
1430 prelude::*,
1431 test_misc::get_opts,
1432 Conn,
1433 DriverError::{MissingNamedParameter, NamedParamsForPositionalQuery},
1434 Error::DriverError,
1435 LocalInfileHandler, Opts, OptsBuilder, Pool, TxOpts,
1436 Value::{self, Bytes, Date, Float, Int, NULL},
1437 };
1438
1439 fn get_system_variable<T>(conn: &mut Conn, name: &str) -> T
1440 where
1441 T: FromValue,
1442 {
1443 conn.query_first::<(String, T), _>(format!("show variables like '{}'", name))
1444 .unwrap()
1445 .unwrap()
1446 .1
1447 }
1448
#[test]
fn should_connect() {
    let mut conn = Conn::new(get_opts()).unwrap();

    // The test server is expected to run with a TRADITIONAL sql_mode.
    let sql_mode = conn
        .query_first::<String, _>("SELECT @@GLOBAL.sql_mode")
        .unwrap()
        .unwrap();
    assert!(sql_mode.contains("TRADITIONAL"));
    assert!(conn.ping().is_ok());

    if crate::test_misc::test_compression() {
        let stream_debug = format!("{:?}", conn.0.stream);
        assert!(stream_debug.contains("Compression"));
    }

    if crate::test_misc::test_ssl() {
        assert!(!conn.is_insecure());
    }
}
1468
#[test]
fn mysql_async_issue_107() -> crate::Result<()> {
    type IssueRow = (Vec<u8>, Vec<u8>, u64, Vec<u8>);

    let mut conn = Conn::new(get_opts())?;
    conn.query_drop(
        r"CREATE TEMPORARY TABLE mysql.issue (
            a BIGINT(20) UNSIGNED,
            b VARBINARY(16),
            c BINARY(32),
            d BIGINT(20) UNSIGNED,
            e BINARY(32)
        )",
    )?;
    conn.query_drop(
        r"INSERT INTO mysql.issue VALUES (
            0,
            0xC066F966B0860000,
            0x7939DA98E524C5F969FC2DE8D905FD9501EBC6F20001B0A9C941E0BE6D50CF44,
            0,
            ''
        ), (
            1,
            '',
            0x076311DF4D407B0854371BA13A5F3FB1A4555AC22B361375FD47B263F31822F2,
            0,
            ''
        )",
    )?;

    // Rows mix empty strings and binary blobs; both rows must decode.
    let decoded: Vec<IssueRow> = conn
        .query_iter("SELECT b, c, d, e FROM mysql.issue")?
        .map(|row| crate::from_row(row.unwrap()))
        .collect();

    assert_eq!(decoded.len(), 2);
    Ok(())
}
1508
#[test]
fn query_traits() -> Result<(), Box<dyn std::error::Error>> {
    // Exercises the query helpers on string literals (`run`, `batch`, `first`,
    // `fetch`, `map`, `fold`) against any connection-like `$conn`. Each helper
    // is checked twice: once on a bare string (text protocol) and once after
    // `.with(...)` (binary protocol). The table ends up holding 42, 43, 44.
    macro_rules! test_query {
        ($conn : expr) => {
            "CREATE TABLE IF NOT EXISTS tmplak (a INT)"
                .run($conn)
                .unwrap();
            "DELETE FROM tmplak".run($conn).unwrap();

            "INSERT INTO tmplak (a) VALUES (?)"
                .with((42,))
                .run($conn)
                .unwrap();

            // Two more rows inserted as a batch: 43 and 44.
            "INSERT INTO tmplak (a) VALUES (?)"
                .with((43..=44).map(|x| (x,)))
                .batch($conn)?;

            let first: Option<u8> = "SELECT a FROM tmplak LIMIT 1".first($conn).unwrap();
            assert_eq!(first, Some(42), "first text");

            let first: Option<u8> = "SELECT a FROM tmplak LIMIT 1"
                .with(())
                .first($conn)
                .unwrap();
            assert_eq!(first, Some(42), "first bin");

            let count = "SELECT a FROM tmplak".run($conn).unwrap().count();
            assert_eq!(count, 3, "run text");

            let count = "SELECT a FROM tmplak".with(()).run($conn).unwrap().count();
            assert_eq!(count, 3, "run bin");

            let all: Vec<u8> = "SELECT a FROM tmplak".fetch($conn).unwrap();
            assert_eq!(all, vec![42, 43, 44], "fetch text");

            let all: Vec<u8> = "SELECT a FROM tmplak".with(()).fetch($conn).unwrap();
            assert_eq!(all, vec![42, 43, 44], "fetch bin");

            let mapped = "SELECT a FROM tmplak".map($conn, |x: u8| x + 1).unwrap();
            assert_eq!(mapped, vec![43, 44, 45], "map text");

            let mapped = "SELECT a FROM tmplak"
                .with(())
                .map($conn, |x: u8| x + 1)
                .unwrap();
            assert_eq!(mapped, vec![43, 44, 45], "map bin");

            let sum = "SELECT a FROM tmplak"
                .fold($conn, 0_u8, |acc, x: u8| acc + x)
                .unwrap();
            assert_eq!(sum, 42 + 43 + 44, "fold text");

            let sum = "SELECT a FROM tmplak"
                .with(())
                .fold($conn, 0_u8, |acc, x: u8| acc + x)
                .unwrap();
            assert_eq!(sum, 42 + 43 + 44, "fold bin");

            "DROP TABLE tmplak".run($conn).unwrap();
        };
    }

    // Plain connection, inside and outside a transaction.
    let mut conn = Conn::new(get_opts())?;

    let mut tx = conn.start_transaction(TxOpts::default())?;
    test_query!(&mut tx);
    tx.rollback()?;

    test_query!(&mut conn);

    // Pooled connection, and a transaction started directly on the pool.
    let pool = Pool::new(get_opts())?;
    let mut pooled_conn = pool.get_conn()?;

    let mut tx = pool.start_transaction(TxOpts::default())?;
    test_query!(&mut tx);
    tx.rollback()?;

    test_query!(&mut pooled_conn);

    Ok(())
}
1591
#[test]
#[should_panic(expected = "Could not connect to address")]
fn should_fail_on_wrong_socket_path() {
    let bogus_socket = OptsBuilder::from_opts(get_opts()).socket(Some("/foo/bar/baz"));
    // Connecting must fail; `unwrap` turns the error into the expected panic.
    Conn::new(bogus_socket).unwrap();
}
1598
#[test]
fn should_fallback_to_tcp_if_cant_switch_to_socket() {
    let mut opts = Opts::from(get_opts());
    // Inject a socket path that cannot exist; connecting must still succeed.
    opts.0.injected_socket = Some("/foo/bar/baz".into());
    Conn::new(opts).unwrap();
}
1605
#[test]
fn should_connect_with_database() {
    const DB_NAME: &str = "mysql";

    let mut conn =
        Conn::new(OptsBuilder::from_opts(get_opts()).db_name(Some(DB_NAME))).unwrap();

    let current_db: String = conn.query_first("SELECT DATABASE()").unwrap().unwrap();
    assert_eq!(current_db, DB_NAME);
}
1617
#[test]
fn should_connect_by_hostname() {
    let by_hostname = OptsBuilder::from_opts(get_opts()).ip_or_hostname(Some("localhost"));
    let mut conn = Conn::new(by_hostname).unwrap();
    assert!(conn.ping().is_ok());
}
1624
#[test]
fn should_select_db() {
    const DB_NAME: &str = "t_select_db";

    let mut conn = Conn::new(get_opts()).unwrap();
    conn.query_drop(format!("CREATE DATABASE IF NOT EXISTS {}", DB_NAME))
        .unwrap();

    assert!(conn.select_db(DB_NAME).is_ok());
    let current_db: String = conn.query_first("SELECT DATABASE()").unwrap().unwrap();
    assert_eq!(current_db, DB_NAME);

    // Clean up the scratch database.
    conn.query_drop(format!("DROP DATABASE {}", DB_NAME))
        .unwrap();
}
1640
#[test]
fn should_execute_queries_and_parse_results() {
    // `conn.query` goes through the text protocol, so every column is
    // decoded from its textual form.
    type TestRow = (String, String, String, String, String, String);

    const CREATE_QUERY: &str = r"CREATE TEMPORARY TABLE mysql.tbl
        (id SERIAL, a TEXT, b INT, c INT UNSIGNED, d DATE, e FLOAT)";
    const INSERT_QUERY_1: &str = r"INSERT
        INTO mysql.tbl(a, b, c, d, e)
        VALUES ('hello', -123, 123, '2014-05-05', 123.123)";
    const INSERT_QUERY_2: &str = r"INSERT
        INTO mysql.tbl(a, b, c, d, e)
        VALUES ('world', -321, 321, '2014-06-06', 321.321)";

    let mut conn = Conn::new(get_opts()).unwrap();

    conn.query_drop(CREATE_QUERY).unwrap();
    assert_eq!(conn.affected_rows(), 0);
    assert_eq!(conn.last_insert_id(), 0);

    conn.query_drop(INSERT_QUERY_1).unwrap();
    assert_eq!(conn.affected_rows(), 1);
    assert_eq!(conn.last_insert_id(), 1);

    conn.query_drop(INSERT_QUERY_2).unwrap();
    assert_eq!(conn.affected_rows(), 1);
    assert_eq!(conn.last_insert_id(), 2);

    // A failing query must leave the connection usable.
    conn.query_drop("SELECT * FROM nonexistent").unwrap_err();

    // The result set is dropped without being read; the following statement
    // must still succeed on the same connection.
    conn.query_iter("SELECT * FROM mysql.tbl").unwrap();
    conn.query_drop("UPDATE mysql.tbl SET a = 'foo'").unwrap();
    assert_eq!(conn.affected_rows(), 2);
    assert_eq!(conn.last_insert_id(), 0);

    assert!(conn
        .query_first::<TestRow, _>("SELECT * FROM mysql.tbl WHERE a = 'bar'")
        .unwrap()
        .is_none());

    let rows: Vec<TestRow> = conn.query("SELECT * FROM mysql.tbl").unwrap();
    assert_eq!(
        rows,
        vec![
            (
                "1".into(),
                "foo".into(),
                "-123".into(),
                "123".into(),
                "2014-05-05".into(),
                "123.123".into()
            ),
            (
                "2".into(),
                "foo".into(),
                "-321".into(),
                "321".into(),
                "2014-06-06".into(),
                "321.321".into()
            )
        ]
    );
}
1703
#[test]
fn should_parse_large_text_result() {
    // Large enough to span several protocol packets.
    const LEN: usize = 20_000_000;

    let mut conn = Conn::new(get_opts()).unwrap();
    let value = conn
        .query_first::<Value, _>("SELECT REPEAT('A', 20000000)")
        .unwrap()
        .unwrap();
    assert_eq!(value, Bytes(vec![b'A'; LEN]));
}
1716
#[test]
// The DATE column is bound as `PrimitiveDateTime`, which this module only
// imports when the `time` feature is enabled.
#[cfg(feature = "time")]
fn should_execute_statements_and_parse_results() {
    const CREATE_QUERY: &str = r"CREATE TEMPORARY TABLE
        mysql.tbl (a TEXT, b INT, c INT UNSIGNED, d DATE, e FLOAT)";
    const INSERT_STMT: &str = r"INSERT
        INTO mysql.tbl (a, b, c, d, e)
        VALUES (?, ?, ?, ?, ?)";

    type RowType = (Value, Value, Value, Value, Value);

    let row1 = (
        Bytes(b"hello".to_vec()),
        Int(-123_i64),
        Int(123_i64),
        Date(2014_u16, 5_u8, 5_u8, 0_u8, 0_u8, 0_u8, 0_u32),
        Float(123.123_f32),
    );
    // Second row exercises NULL binding for the nullable columns.
    let row2 = (Bytes(b"".to_vec()), NULL, NULL, NULL, Float(321.321_f32));

    let mut conn = Conn::new(get_opts()).unwrap();
    conn.query_drop(CREATE_QUERY).unwrap();

    let insert_stmt = conn.prep(INSERT_STMT).unwrap();
    assert_eq!(insert_stmt.connection_id(), conn.connection_id());
    conn.exec_drop(
        &insert_stmt,
        (
            from_value::<String>(row1.0.clone()),
            from_value::<i32>(row1.1.clone()),
            from_value::<u32>(row1.2.clone()),
            from_value::<PrimitiveDateTime>(row1.3.clone()),
            from_value::<f32>(row1.4.clone()),
        ),
    )
    .unwrap();
    conn.exec_drop(
        &insert_stmt,
        (
            from_value::<String>(row2.0.clone()),
            row2.1.clone(),
            row2.2.clone(),
            row2.3.clone(),
            from_value::<f32>(row2.4.clone()),
        ),
    )
    .unwrap();

    // Read back over the binary protocol; values must round-trip exactly.
    let select_stmt = conn.prep("SELECT * from mysql.tbl").unwrap();
    let rows: Vec<RowType> = conn.exec(&select_stmt, ()).unwrap();

    assert_eq!(rows, vec![row1, row2]);
}
1769
#[test]
fn should_parse_large_binary_result() {
    const LEN: usize = 20_000_000;

    let mut conn = Conn::new(get_opts()).unwrap();
    let repeat_stmt = conn.prep("SELECT REPEAT('A', 20000000)").unwrap();
    let value = conn
        .exec_first::<Value, _, _>(&repeat_stmt, ())
        .unwrap()
        .unwrap();
    assert_eq!(value, Bytes(vec![b'A'; LEN]));
}
1780
#[test]
fn manually_closed_stmt() {
    // With a one-entry cache, preparing a new statement evicts the previous one.
    let mut conn = Conn::new(get_opts().stmt_cache_size(1)).unwrap();

    // Prepare, run and explicitly close the same query twice.
    for _ in 0..2 {
        let stmt = conn.prep("SELECT 1").unwrap();
        conn.exec_drop(&stmt, ()).unwrap();
        conn.close(stmt).unwrap();
    }

    let stmt = conn.prep("SELECT 2").unwrap();
    conn.exec_drop(&stmt, ()).unwrap();
}
1794
#[test]
fn should_start_commit_and_rollback_transactions() {
    // Fetches the single row of `SELECT COUNT(a)` as text-protocol values.
    fn count_row(conn: &mut Conn) -> Vec<Value> {
        let mut result = conn.query_iter("SELECT COUNT(a) from mysql.tbl").unwrap();
        result.next().unwrap().unwrap().unwrap()
    }

    let mut conn = Conn::new(get_opts()).unwrap();
    conn.query_drop(
        "CREATE TEMPORARY TABLE mysql.tbl(id INT NOT NULL PRIMARY KEY AUTO_INCREMENT, a INT)",
    )
    .unwrap();

    // Committed inserts persist.
    {
        let mut t = conn.start_transaction(TxOpts::default()).unwrap();
        t.query_drop("INSERT INTO mysql.tbl(a) VALUES(1)").unwrap();
        assert_eq!(t.last_insert_id(), Some(1));
        assert_eq!(t.affected_rows(), 1);
        t.query_drop("INSERT INTO mysql.tbl(a) VALUES(2)").unwrap();
        t.commit().unwrap();
    }
    assert_eq!(count_row(&mut conn), vec![Bytes(b"2".to_vec())]);

    // A failing statement in a transaction that is then dropped changes nothing.
    {
        let mut t = conn.start_transaction(TxOpts::default()).unwrap();
        t.query_drop("INSERT INTO tbl2(a) VALUES(1)").unwrap_err();
    }
    assert_eq!(count_row(&mut conn), vec![Bytes(b"2".to_vec())]);

    // An explicit rollback discards the inserts.
    {
        let mut t = conn.start_transaction(TxOpts::default()).unwrap();
        t.query_drop("INSERT INTO mysql.tbl(a) VALUES(1)").unwrap();
        t.query_drop("INSERT INTO mysql.tbl(a) VALUES(2)").unwrap();
        t.rollback().unwrap();
    }
    assert_eq!(count_row(&mut conn), vec![Bytes(b"2".to_vec())]);

    // Binary-protocol inserts commit as well.
    let mut tx = conn.start_transaction(TxOpts::default()).unwrap();
    tx.exec_drop("INSERT INTO mysql.tbl(a) VALUES(?)", (3,))
        .unwrap();
    tx.exec_drop("INSERT INTO mysql.tbl(a) VALUES(?)", (4,))
        .unwrap();
    tx.commit().unwrap();
    assert_eq!(count_row(&mut conn), vec![Bytes(b"4".to_vec())]);

    // Dropping an uncommitted transaction is expected to roll it back.
    let mut tx = conn.start_transaction(TxOpts::default()).unwrap();
    tx.exec_drop("INSERT INTO mysql.tbl(a) VALUES(?)", (5,))
        .unwrap();
    tx.exec_drop("INSERT INTO mysql.tbl(a) VALUES(?)", (6,))
        .unwrap();
    drop(tx);
    assert_eq!(
        conn.query_first("SELECT COUNT(a) from mysql.tbl").unwrap(),
        Some(4_usize),
    );
}
#[test]
fn should_handle_LOCAL_INFILE_with_custom_handler() {
    const LINE_LEN: usize = 65535;
    const LINE_COUNT: usize = 1536;

    let mut conn = Conn::new(get_opts()).unwrap();
    conn.query_drop("CREATE TEMPORARY TABLE mysql.tbl(a TEXT)")
        .unwrap();

    // The handler ignores the requested file name and streams generated lines.
    conn.set_local_infile_handler(Some(LocalInfileHandler::new(|_, stream| {
        let mut line = vec![b'Z'; LINE_LEN];
        line.push(b'\n');
        for _ in 0..LINE_COUNT {
            stream.write_all(&line)?;
        }
        Ok(())
    })));

    if let Err(err) = conn.query_drop("LOAD DATA LOCAL INFILE 'file_name' INTO TABLE mysql.tbl") {
        // The server may refuse LOCAL INFILE; treat that as a skip.
        if err.to_string().contains("not allowed") {
            return;
        }
        panic!("ERROR {}", err);
    }

    let count: usize = conn
        .query_iter("SELECT * FROM mysql.tbl")
        .unwrap()
        .map(|row| {
            assert_eq!(from_row::<(Vec<u8>,)>(row.unwrap()).0.len(), LINE_LEN);
            1
        })
        .sum();
    assert_eq!(count, LINE_COUNT);
}
1907
#[test]
fn should_reset_connection() {
    let mut conn = Conn::new(get_opts()).unwrap();
    conn.query_drop("CREATE TEMPORARY TABLE `mysql`.`test` (`test` VARCHAR(255) NULL);")
        .unwrap();
    conn.query_drop("INSERT INTO `mysql`.`test` (`test`) VALUES ('foo');")
        .unwrap();
    assert_eq!(conn.affected_rows(), 1);

    conn.reset().unwrap();

    // After reset the counters are cleared and the temporary table is gone.
    assert_eq!(conn.affected_rows(), 0);
    conn.query_drop("SELECT * FROM `mysql`.`test`;")
        .unwrap_err();
}
1924
#[test]
fn should_change_user() -> crate::Result<()> {
    // Decides from `(is_mariadb, server_version)` whether a plugin can be tested.
    type ShouldRunFn = fn(bool, (u16, u16, u16)) -> bool;
    // Produces the statements that create user `__mats` with the given password.
    type CreateUserFn = fn(bool, (u16, u16, u16), &str) -> Vec<String>;

    #[allow(clippy::type_complexity)]
    const TEST_MATRIX: [(&str, ShouldRunFn, CreateUserFn); 4] = [
        (
            "mysql_old_password",
            |is_mariadb, version| is_mariadb || version < (5, 7, 0),
            |is_mariadb, version, pass| {
                if is_mariadb {
                    vec![
                        "CREATE USER '__mats'@'%' IDENTIFIED WITH mysql_old_password"
                            .into(),
                        "SET old_passwords=1".into(),
                        format!("ALTER USER '__mats'@'%' IDENTIFIED BY '{pass}'"),
                        "SET old_passwords=0".into(),
                    ]
                } else if matches!(version, (5, 6, _)) {
                    vec![
                        "CREATE USER '__mats'@'%' IDENTIFIED WITH mysql_old_password"
                            .into(),
                        format!("SET PASSWORD FOR '__mats'@'%' = OLD_PASSWORD('{pass}')"),
                    ]
                } else {
                    vec![
                        "CREATE USER '__mats'@'%'".into(),
                        format!("SET PASSWORD FOR '__mats'@'%' = PASSWORD('{pass}')"),
                    ]
                }
            },
        ),
        (
            "mysql_native_password",
            |is_mariadb, version| is_mariadb || version < (8, 4, 0),
            |is_mariadb, version, pass| {
                if is_mariadb {
                    vec![
                        format!("CREATE USER '__mats'@'%' IDENTIFIED WITH mysql_native_password AS PASSWORD('{pass}')")
                    ]
                } else if version < (8, 0, 0) {
                    vec![
                        "CREATE USER '__mats'@'%' IDENTIFIED WITH mysql_native_password"
                            .into(),
                        "SET old_passwords = 0".into(),
                        format!("SET PASSWORD FOR '__mats'@'%' = PASSWORD('{pass}')"),
                    ]
                } else {
                    vec![
                        format!("CREATE USER '__mats'@'%' IDENTIFIED WITH mysql_native_password BY '{pass}'")
                    ]
                }
            },
        ),
        (
            "caching_sha2_password",
            // MySQL went from 5.7 straight to 8.0 (there is no 5.8), so this is
            // the first release line where the plugin exists.
            |is_mariadb, version| !is_mariadb && version >= (8, 0, 0),
            |_is_mariadb, _version, pass| {
                vec![
                    format!("CREATE USER '__mats'@'%' IDENTIFIED WITH caching_sha2_password BY '{pass}'")
                ]
            },
        ),
        (
            "client_ed25519",
            |is_mariadb, version| is_mariadb && version >= (10, 4, 0),
            |_is_mariadb, _version, pass| {
                vec![
                    format!("CREATE USER '__mats'@'%' IDENTIFIED WITH ed25519 AS PASSWORD('{pass}')")
                ]
            },
        ),
    ];

    // Ten random lowercase ASCII letters.
    fn random_pass() -> String {
        let mut rng = rand::thread_rng();
        let mut pass = [0u8; 10];
        pass.try_fill(&mut rng).unwrap();
        IntoIterator::into_iter(pass)
            .map(|x| ((x % (123 - 97)) + 97) as char)
            .collect()
    }

    let mut conn = Conn::new(get_opts()).unwrap();

    // `change_user` with default options must clear session variables.
    assert_eq!(
        conn.query_first::<Value, _>("SELECT @foo")
            .unwrap()
            .unwrap(),
        Value::NULL
    );

    conn.query_drop("SET @foo = 'foo'").unwrap();

    assert_eq!(
        conn.query_first::<String, _>("SELECT @foo")
            .unwrap()
            .unwrap(),
        "foo",
    );

    conn.change_user(Default::default()).unwrap();
    assert_eq!(
        conn.query_first::<Value, _>("SELECT @foo")
            .unwrap()
            .unwrap(),
        Value::NULL
    );

    // For each applicable auth plugin, create `__mats` and switch to it.
    for (plugin, should_run, create_statements) in TEST_MATRIX {
        dbg!(plugin);
        let is_mariadb = conn.0.mariadb_server_version.is_some();
        let version = conn.server_version();

        if should_run(is_mariadb, version) {
            let pass = random_pass();

            // Best-effort removal of a leftover user from a previous iteration/run.
            let statement =
                "DROP USER /*!50700 IF EXISTS */ /*M!50700 IF EXISTS */ '__mats'";
            _ = conn.query_drop(dbg!(statement));

            for statement in create_statements(is_mariadb, version, &pass) {
                conn.query_drop(dbg!(statement)).unwrap();
            }

            let mut conn2 = Conn::new(get_opts().secure_auth(false)).unwrap();
            conn2
                .change_user(
                    crate::ChangeUserOpts::default()
                        .with_db_name(None)
                        .with_user(Some("__mats".into()))
                        .with_pass(Some(pass)),
                )
                .unwrap();

            let (db, user) = conn2
                .query_first::<(Option<String>, String), _>("SELECT DATABASE(), USER();")
                .unwrap()
                .unwrap();
            assert_eq!(db, None);
            assert!(user.starts_with("__mats"));
        }
    }

    Ok(())
}
2076
#[test]
fn prep_exec() {
    let mut conn = Conn::new(get_opts()).unwrap();

    // Two statements with different named params coexist on one connection.
    let select_foo = conn.prep("SELECT :foo").unwrap();
    let select_bar = conn.prep("SELECT :bar").unwrap();

    let foo_rows: Vec<String> = conn
        .exec(&select_foo, params! { "foo" => "foo" })
        .unwrap();
    assert_eq!(foo_rows, vec![String::from("foo")]);

    let bar_rows: Vec<String> = conn
        .exec(&select_bar, params! { "bar" => "bar" })
        .unwrap();
    assert_eq!(bar_rows, vec![String::from("bar")]);
}
2094
#[test]
fn should_connect_via_socket_for_127_0_0_1() {
    // If the connection is insecure, it is expected to have been re-established
    // over a local socket. NOTE(review): assumes `get_opts()` targets a
    // loopback address — confirm in test_misc.
    let opts = OptsBuilder::from_opts(get_opts());
    let mut conn = Conn::new(opts).unwrap();
    if conn.is_insecure() {
        assert!(
            conn.is_socket(),
            "Did not reconnect via socket {:?}",
            // Diagnostics, evaluated only when the assertion fails: the socket
            // preference, whether the address is loopback, and the outcome of
            // connecting by hand with the options from `can_improved`.
            (
                conn.0.opts.get_prefer_socket(),
                conn.0.opts.addr_is_loopback(),
                conn.can_improved().and_then(|opts| {
                    opts.map(|opts| {
                        let mut new = crate::conn::Conn(Box::new(ConnInner::empty(opts)));
                        new.connect_stream().and_then(|_| {
                            new.connect()?;
                            Ok(new)
                        })
                    })
                    .transpose()
                }),
            )
        );
    }
}
2120
#[test]
fn should_connect_via_socket_localhost() {
    // Same check as the 127.0.0.1 variant, but addressing the server by the
    // `localhost` hostname.
    let opts = OptsBuilder::from_opts(get_opts()).ip_or_hostname(Some("localhost"));
    let mut conn = Conn::new(opts).unwrap();
    if conn.is_insecure() {
        assert!(
            conn.is_socket(),
            "Did not reconnect via socket {:?}",
            // Diagnostics, evaluated only when the assertion fails: the socket
            // preference, whether the address is loopback, and the outcome of
            // connecting by hand with the options from `can_improved`.
            (
                conn.0.opts.get_prefer_socket(),
                conn.0.opts.addr_is_loopback(),
                conn.can_improved().and_then(|opts| {
                    opts.map(|opts| {
                        let mut new = crate::conn::Conn(Box::new(ConnInner::empty(opts)));
                        new.connect_stream().and_then(|_| {
                            new.connect()?;
                            Ok(new)
                        })
                    })
                    .transpose()
                }),
            )
        );
    }
}
2146
#[test]
fn issue_306() {
    // Dropping a `QueryResult` whose connection was killed mid-stream must not
    // hang: the worker signals once the result has been dropped.
    let (tx, rx) = channel::<()>();
    let handle = spawn(move || {
        let mut c1 = Conn::new(get_opts()).unwrap();
        let c1_id = c1.connection_id();
        let mut c2 = Conn::new(get_opts()).unwrap();
        let query_result = c1.query_iter("DO 1; SELECT SLEEP(1); DO 2;").unwrap();
        c2.query_drop(format!("KILL {c1_id}")).unwrap();
        drop(c2);
        drop(query_result);
        tx.send(()).unwrap();
    });
    // Same two-second deadline as before, but return as soon as the worker is
    // done instead of always sleeping the full duration.
    assert!(rx.recv_timeout(Duration::from_secs(2)).is_ok());
    handle.join().unwrap();
}
2167
#[test]
fn reset_does_work() {
    let mut conn = Conn::new(get_opts()).unwrap();
    let id_before_reset = conn.connection_id();

    conn.query_drop("SET @foo = 'foo'").unwrap();
    let foo: String = conn.query_first("SELECT @foo").unwrap().unwrap();
    assert_eq!(foo, "foo");

    conn.reset().unwrap();

    // Same server connection, but the user variable is gone.
    assert_eq!(id_before_reset, conn.connection_id());
    let foo: Value = conn.query_first("SELECT @foo").unwrap().unwrap();
    assert_eq!(foo, Value::NULL);
}
2184
#[test]
fn should_drop_multi_result_set() {
    let mut conn =
        Conn::new(OptsBuilder::from_opts(get_opts()).db_name(Some("mysql"))).unwrap();
    conn.query_drop("CREATE TEMPORARY TABLE TEST_TABLE ( name varchar(255) )")
        .unwrap();
    conn.exec_drop("SELECT * FROM TEST_TABLE", ()).unwrap();
    conn.query_drop(
        r"
        INSERT INTO TEST_TABLE (name) VALUES ('one');
        INSERT INTO TEST_TABLE (name) VALUES ('two');
        INSERT INTO TEST_TABLE (name) VALUES ('three');",
    )
    .unwrap();
    conn.exec_drop("SELECT * FROM TEST_TABLE", ()).unwrap();

    let mut results = conn
        .query_iter(
            r"
            SELECT * FROM TEST_TABLE;
            INSERT INTO TEST_TABLE (name) VALUES ('one');
            DO 0;",
        )
        .unwrap();

    // Visit every result set without reading its rows.
    while let Some(set) = results.iter() {
        set.affected_rows();
    }
}
2214
#[test]
fn should_handle_multi_result_set() {
    let opts = OptsBuilder::from_opts(get_opts())
        .prefer_socket(false)
        .db_name(Some("mysql"));
    let mut conn = Conn::new(opts).unwrap();
    conn.query_drop("DROP PROCEDURE IF EXISTS multi").unwrap();
    conn.query_drop(
        r#"CREATE PROCEDURE multi() BEGIN
                SELECT 1 UNION ALL SELECT 2;
                DO 1;
                SELECT 3 UNION ALL SELECT 4;
                DO 1;
                DO 1;
                SELECT REPEAT('A', 17000000);
                SELECT REPEAT('A', 17000000);
            END"#,
    )
    .unwrap();

    // Read only the first two row-producing result sets of the procedure;
    // the remaining ones are left unread when `call` goes out of scope.
    {
        let mut call = conn.query_iter("CALL multi()").unwrap();
        let first_set = call
            .by_ref()
            .map(|row| row.unwrap().unwrap().pop().unwrap())
            .collect::<Vec<crate::Value>>();
        assert_eq!(first_set, vec![Bytes(b"1".to_vec()), Bytes(b"2".to_vec())]);
        let second_set = call
            .by_ref()
            .map(|row| row.unwrap().unwrap().pop().unwrap())
            .collect::<Vec<crate::Value>>();
        assert_eq!(second_set, vec![Bytes(b"3".to_vec()), Bytes(b"4".to_vec())]);
    }

    // Walk three result sets explicitly, checking each one's single row.
    let mut results = conn.query_iter("SELECT 1; SELECT 2; SELECT 3;").unwrap();
    let mut set_no = 0;
    while let Some(set) = results.iter() {
        set_no += 1;
        for row in set {
            match set_no {
                1 => assert_eq!(row.unwrap().unwrap(), vec![Bytes(b"1".to_vec())]),
                2 => assert_eq!(row.unwrap().unwrap(), vec![Bytes(b"2".to_vec())]),
                3 => assert_eq!(row.unwrap().unwrap(), vec![Bytes(b"3".to_vec())]),
                _ => unreachable!(),
            }
        }
    }
    assert_eq!(set_no, 3);
}
2262
#[test]
fn issue_273() {
    let mut conn = Conn::new(OptsBuilder::from_opts(get_opts()).prefer_socket(false)).unwrap();

    "DROP FUNCTION IF EXISTS f1".run(&mut conn).unwrap();
    r"CREATE DEFINER=`root`@`localhost` FUNCTION `f1`(p_arg INT, p_arg2 INT) RETURNS int
    DETERMINISTIC
    BEGIN
    RETURN p_arg + p_arg2;
    END"
    .run(&mut conn)
    .unwrap();

    // Calling the stored function through a prepared statement must succeed.
    "SELECT f1(?, ?)"
        .with((100u8, 100u8))
        .run(&mut conn)
        .unwrap();
}
2282
#[test]
fn issue_285() {
    // The worker runs a multi-statement query whose second statement is
    // invalid. The test is meant to catch the worker hanging instead of
    // finishing.
    let (tx, rx) = sync_channel::<()>(0);

    let handle = std::thread::spawn(move || {
        let mut conn = Conn::new(get_opts()).unwrap();
        const INVALID_SQL: &str = r#"
        CREATE TEMPORARY TABLE IF NOT EXISTS `user_details` (
            `user_id` int(11) NOT NULL AUTO_INCREMENT,
            `username` varchar(255) DEFAULT NULL,
            `first_name` varchar(50) DEFAULT NULL,
            `last_name` varchar(50) DEFAULT NULL,
            PRIMARY KEY (`user_id`)
        );

        INSERT INTO `user_details` (`user_id`, `username`, `first_name`, `last_name`)
        VALUES (1, 'rogers63', 'david')
        "#;

        conn.query_iter(INVALID_SQL).unwrap();
        tx.send(()).unwrap();
    });

    // A finite, generous deadline: the previous 100_000 s (~28 h) effectively
    // turned a hang into a stuck CI job rather than a test failure.
    match rx.recv_timeout(Duration::from_secs(30)) {
        Ok(_) => handle.join().unwrap(),
        Err(_) => panic!("test failed"),
    }
}
2311
#[test]
fn should_work_with_named_params() {
    let mut conn = Conn::new(get_opts()).unwrap();

    // Explicitly prepared statement: the name `:a` appears twice and binds the same value.
    {
        let stmt = conn.prep("SELECT :a, :b, :a, :c").unwrap();
        let row: (u8, u8, u8, u8) = conn
            .exec_first(&stmt, params! {"a" => 1, "b" => 2, "c" => 3})
            .unwrap()
            .unwrap();
        assert_eq!((1_u8, 2_u8, 1_u8, 3_u8), row);
    }

    // Raw query text with named parameters, including one used inside an expression.
    let named = params! {
        "a" => 1,
        "b" => 2,
        "c" => 3,
    };
    let row: (u8, u8, u8, u8) = conn
        .exec_first("SELECT :a, :b, :a + :b, :c", named)
        .unwrap()
        .unwrap();
    assert_eq!((1_u8, 2_u8, 3_u8, 3_u8), row);
}
2337
#[test]
fn should_return_error_on_missing_named_parameter() {
    let mut conn = Conn::new(get_opts()).unwrap();
    // `:d` is referenced by the statement but is not among the supplied parameters.
    let stmt = conn.prep("SELECT :a, :b, :a, :c, :d").unwrap();
    let outcome =
        conn.exec_first::<crate::Row, _, _>(&stmt, params! {"a" => 1, "b" => 2, "c" => 3,});

    let reported_missing_d = matches!(
        outcome,
        Err(DriverError(MissingNamedParameter(ref x))) if x == "d"
    );
    if !reported_missing_d {
        panic!("MissingNamedParameter error expected");
    }
}
2349
#[test]
fn should_return_error_on_named_params_for_positional_statement() {
    let mut conn = Conn::new(get_opts()).unwrap();
    // The statement only has positional `?` placeholders, yet named parameters are passed.
    let stmt = conn.prep("SELECT ?, ?, ?, ?, ?").unwrap();
    let outcome = conn.exec_drop(&stmt, params! {"a" => 1, "b" => 2, "c" => 3,});

    if let Err(DriverError(NamedParamsForPositionalQuery)) = outcome {
        return;
    }
    panic!("NamedParamsForPositionalQuery error expected");
}
2360
#[test]
fn should_handle_tcp_connect_timeout() {
    use crate::error::{DriverError::ConnectTimeout, Error::DriverError};

    let connect_timeout = Some(::std::time::Duration::from_millis(1000));

    // With a connect timeout configured, a reachable server must still connect and answer a ping.
    let reachable = OptsBuilder::from_opts(get_opts())
        .prefer_socket(false)
        .tcp_connect_timeout(connect_timeout);
    assert!(Conn::new(reachable).unwrap().ping().is_ok());

    // 192.168.255.255 is presumably unreachable from the test host, so the TCP connect
    // attempt is expected to hit the timeout and surface `ConnectTimeout`.
    let unreachable = OptsBuilder::from_opts(get_opts())
        .prefer_socket(false)
        .tcp_connect_timeout(connect_timeout)
        .ip_or_hostname(Some("192.168.255.255"));
    let err = Conn::new(unreachable).unwrap_err();
    if !matches!(err, DriverError(ConnectTimeout)) {
        panic!("Unexpected error: {}", err);
    }
}
2379
#[test]
fn should_set_additional_capabilities() {
    use crate::consts::CapabilityFlags;

    // The UPDATE below writes the value `b` already holds; with CLIENT_FOUND_ROWS requested,
    // the test expects it to still report one affected row.
    let opts = OptsBuilder::from_opts(get_opts())
        .additional_capabilities(CapabilityFlags::CLIENT_FOUND_ROWS);
    let mut conn = Conn::new(opts).unwrap();

    let setup = [
        "CREATE TEMPORARY TABLE mysql.tbl (a INT, b TEXT)",
        "INSERT INTO mysql.tbl (a, b) VALUES (1, 'foo')",
    ];
    for statement in setup.iter() {
        conn.query_drop(statement).unwrap();
    }

    let update = conn
        .query_iter("UPDATE mysql.tbl SET b = 'foo' WHERE a = 1")
        .unwrap();
    assert_eq!(update.affected_rows(), 1);
}
2397
#[test]
fn should_bind_before_connect() {
    // Bind the client side of the TCP socket to 127.0.0.1 on a pseudo-random port in
    // [28000, 30000) before connecting.
    let port = 28000 + (rand::random::<u16>() % 2000);
    let opts = OptsBuilder::from_opts(get_opts())
        .prefer_socket(false)
        .ip_or_hostname(Some("localhost"))
        .bind_address(Some(([127, 0, 0, 1], port)));
    let conn = Conn::new(opts).unwrap();

    // The bound address must appear in the connection's Debug output; accept either rendering.
    let debug_format: String = format!("{:?}", conn);
    let accepted = [
        format!("addr: V4(127.0.0.1:{})", port),
        format!("addr: 127.0.0.1:{}", port),
    ];
    assert!(
        accepted
            .iter()
            .any(|needle| debug_format.contains(needle.as_str())),
        "debug_format: {}",
        debug_format
    );
}
2415
#[test]
fn should_bind_before_connect_with_timeout() {
    // Same as the plain bind test, but on the timeout-aware connect path and a port in
    // [30000, 32000).
    let port = 30000 + (rand::random::<u16>() % 2000);
    let connect_timeout = Some(::std::time::Duration::from_millis(1000));
    let opts = OptsBuilder::from_opts(get_opts())
        .prefer_socket(false)
        .ip_or_hostname(Some("localhost"))
        .bind_address(Some(([127, 0, 0, 1], port)))
        .tcp_connect_timeout(connect_timeout);

    let mut conn = Conn::new(opts).unwrap();
    assert!(conn.ping().is_ok());

    let debug_format: String = format!("{:?}", conn);
    let mentions_bound_addr = [
        format!("addr: V4(127.0.0.1:{})", port),
        format!("addr: 127.0.0.1:{}", port),
    ]
    .iter()
    .any(|expected| debug_format.contains(expected.as_str()));
    assert!(mentions_bound_addr, "debug_format: {}", debug_format);
}
2435
#[test]
fn should_not_cache_statements_if_stmt_cache_size_is_zero() {
    let opts = OptsBuilder::from_opts(get_opts()).stmt_cache_size(0);
    let mut conn = Conn::new(opts).unwrap();

    // Prepare all three statements first, then close them in the same order. With a zero-sized
    // cache the test expects each close to be counted by the server.
    let prepared: Vec<_> = ["DO 1", "DO 2", "DO 3"]
        .iter()
        .map(|query| conn.prep(*query).unwrap())
        .collect();
    for stmt in prepared {
        conn.close(stmt).unwrap();
    }

    let (_, closed): (Value, u8) = conn
        .query_first("SHOW SESSION STATUS LIKE 'Com_stmt_close';")
        .unwrap()
        .unwrap();
    assert_eq!(closed, 3);
}
2455
#[test]
fn should_hold_stmt_cache_size_bounds() {
    let opts = OptsBuilder::from_opts(get_opts()).stmt_cache_size(3);
    let mut conn = Conn::new(opts).unwrap();

    // Eight prepares over six distinct queries into a cache of capacity three.
    let queries = ["DO 1", "DO 2", "DO 3", "DO 1", "DO 4", "DO 3", "DO 5", "DO 6"];
    for query in queries.iter() {
        conn.prep(query).unwrap();
    }

    // Three statements had to be evicted, and each eviction closes it on the server.
    let (_, closed): (String, usize) = conn
        .query_first("SHOW SESSION STATUS LIKE 'Com_stmt_close'")
        .unwrap()
        .unwrap();
    assert_eq!(closed, 3);

    let mut cached = conn
        .0
        .stmt_cache
        .iter()
        .map(|(_, entry)| &**entry.query.0.as_ref())
        .collect::<Vec<&[u8]>>();
    cached.sort();
    assert_eq!(cached, &[b"DO 3", b"DO 5", b"DO 6"]);
}
2486
#[test]
fn should_handle_json_columns() {
    use crate::{Deserialized, Serialized};
    use serde::{Deserialize, Serialize};
    use serde_json::Value as Json;
    use std::str::FromStr;

    #[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
    pub struct DecTest {
        foo: String,
        quux: (u64, String),
    }

    let expected = DecTest {
        foo: "bar".into(),
        quux: (42, "hello".into()),
    };

    let mut conn = Conn::new(get_opts()).unwrap();
    // Fall back to a TEXT column when the server rejects the JSON column type.
    conn.query_drop("CREATE TEMPORARY TABLE mysql.tbl(a VARCHAR(32), b JSON)")
        .or_else(|_| conn.query_drop("CREATE TEMPORARY TABLE mysql.tbl(a VARCHAR(32), b TEXT)"))
        .unwrap();

    conn.exec_drop(
        r#"INSERT INTO mysql.tbl VALUES ('hello', ?)"#,
        (Serialized(&expected),),
    )
    .unwrap();

    // Text protocol: read the stored document back as a generic JSON value.
    let text_row: (String, Json) = conn
        .query_first("SELECT a, b FROM mysql.tbl")
        .unwrap()
        .unwrap();
    let expected_json = Json::from_str(r#"{"foo": "bar", "quux": [42, "hello"]}"#).unwrap();
    assert_eq!(text_row, ("hello".into(), expected_json));

    // Binary protocol: deserialize the column straight into `DecTest`.
    let row = conn
        .exec_first("SELECT a, b FROM mysql.tbl WHERE a = ?", ("hello",))
        .unwrap()
        .unwrap();
    let (a, Deserialized(b)) = from_row(row);
    assert_eq!((a, b), (String::from("hello"), expected));
}
2538
#[test]
fn should_set_connect_attrs() {
    // First connection: an empty user attribute map, so only the driver's own attributes
    // are expected.
    let opts = OptsBuilder::from_opts(
        get_opts().connect_attrs::<String, String>(Some(Default::default())),
    );
    let mut conn = Conn::new(opts).unwrap();

    // Attributes are only checked on MySQL >= 5.6.0 or MariaDB >= 10.0.0.
    let support_connect_attrs = match (conn.0.server_version, conn.0.mariadb_server_version) {
        (Some(ref version), _) if *version >= (5, 6, 0) => true,
        (_, Some(ref version)) if *version >= (10, 0, 0) => true,
        _ => false,
    };

    if support_connect_attrs {
        // Attributes are read back from performance_schema, which must be enabled and must
        // allow attribute sizes outside 0..=128 (the test asks for -1).
        if get_system_variable::<String>(&mut conn, "performance_schema") != "ON" {
            panic!("The system variable `performance_schema` is off. Restart the MySQL server with `--performance_schema=on` to pass the test.");
        }
        let attrs_size: i32 =
            get_system_variable(&mut conn, "performance_schema_session_connect_attrs_size");
        if (0..=128).contains(&attrs_size) {
            panic!("The system variable `performance_schema_session_connect_attrs_size` is {}. Restart the MySQL server with `--performance_schema_session_connect_attrs_size=-1` to pass the test.", attrs_size);
        }

        /// Asserts that every `(name, value)` pair in `expected_values` is reported for the
        /// current session; additional server-side attributes are ignored.
        fn assert_connect_attrs(conn: &mut Conn, expected_values: &[(&str, &str)]) {
            let mut actual_values = HashMap::new();
            for row in conn.query_iter("SELECT attr_name, attr_value FROM performance_schema.session_account_connect_attrs WHERE processlist_id = connection_id()").unwrap() {
                let (name, value) = from_row::<(String, String)>(row.unwrap());
                actual_values.insert(name, value);
            }

            for (name, value) in expected_values {
                assert_eq!(actual_values.get(*name), Some(&value.to_string()));
            }
        }

        let pid = process::id().to_string();
        let prog_name = std::env::args_os()
            .next()
            .unwrap()
            .to_string_lossy()
            .into_owned();
        let mut expected_values = vec![
            ("_client_name", "rust-mysql-simple"),
            ("_client_version", env!("CARGO_PKG_VERSION")),
            ("_os", env!("CARGO_CFG_TARGET_OS")),
            ("_pid", &pid),
            ("_platform", env!("CARGO_CFG_TARGET_ARCH")),
            ("program_name", &prog_name),
        ];

        assert_connect_attrs(&mut conn, &expected_values);

        // Second connection: user attributes are added, and a user-supplied `program_name`
        // replaces the default one.
        let opts = OptsBuilder::from_opts(get_opts());
        let mut connect_attrs = HashMap::with_capacity(3);
        connect_attrs.insert("foo", "foo val");
        connect_attrs.insert("bar", "bar val");
        connect_attrs.insert("program_name", "my program name");
        let mut conn = Conn::new(opts.connect_attrs(Some(connect_attrs))).unwrap();

        // Drop the default `program_name` expectation by name instead of relying on it being
        // the last element of the list.
        expected_values.retain(|&(name, _)| name != "program_name");
        expected_values.push(("foo", "foo val"));
        expected_values.push(("bar", "bar val"));
        expected_values.push(("program_name", "my program name"));
        assert_connect_attrs(&mut conn, &expected_values);
    }
}
2609
#[test]
fn test_metadata_caching() {
    use crate::consts::ColumnType;

    let mut conn = Conn::new(get_opts()).unwrap();
    // Only exercised against MariaDB; any other server returns early.
    if conn.0.mariadb_server_version.is_none() {
        return;
    }

    conn.query_drop(
        r"CREATE TEMPORARY TABLE t_metadata_caching (
            id INT NOT NULL PRIMARY KEY AUTO_INCREMENT,
            val VARCHAR(32) NOT NULL)",
    )
    .unwrap();

    // The row assertions below depend on both inserts, so fail right here if either fails
    // instead of discarding the error.
    let insert_stmt = conn
        .prep("INSERT INTO t_metadata_caching (val) VALUES (?)")
        .unwrap();
    conn.exec_drop(&insert_stmt, ("AAA",)).unwrap();
    conn.exec_drop(&insert_stmt, ("BB",)).unwrap();

    // 1. Two-column SELECT: column metadata from PREPARE must match that from EXECUTE.
    let mut ps = conn.prep("SELECT id, val FROM t_metadata_caching").unwrap();

    let mut columns_from_prep = ps.columns();
    let mut metadata_from_prep: Vec<(String, ColumnType)> = columns_from_prep
        .iter()
        .map(|column| (column.name_str().to_string(), column.column_type()))
        .collect();

    let mut query_result = conn.exec_iter(&ps, ()).unwrap();
    let mut columns_from_exec1 = query_result.columns();
    let mut metadata_from_exec1: Vec<(String, ColumnType)> = columns_from_exec1
        .as_ref()
        .iter()
        .map(|column| (column.name_str().to_string(), column.column_type()))
        .collect();

    assert_eq!(metadata_from_prep, metadata_from_exec1);
    assert_eq!(metadata_from_prep.len(), 2);
    assert_eq!(metadata_from_prep[0].0, "id");
    assert_eq!(metadata_from_prep[0].1, ColumnType::MYSQL_TYPE_LONG);
    assert_eq!(metadata_from_prep[1].0, "val");
    assert_eq!(metadata_from_prep[1].1, ColumnType::MYSQL_TYPE_VAR_STRING);

    let fetched_rows: Vec<(i32, String)> = query_result
        .map(|row_result| crate::from_row(row_result.unwrap()))
        .collect();

    let expected_rows = [(1, "AAA".to_string()), (2, "BB".to_string())];
    assert_eq!(fetched_rows.len(), expected_rows.len());

    for (fetched, expected) in fetched_rows.iter().zip(expected_rows.iter()) {
        assert_eq!(fetched, expected);
    }

    // 2. Single-column SELECT with a parameter.
    ps = conn
        .prep("SELECT val FROM t_metadata_caching WHERE id = ?")
        .unwrap();
    columns_from_prep = ps.columns();
    metadata_from_prep = columns_from_prep
        .iter()
        .map(|column| (column.name_str().to_string(), column.column_type()))
        .collect();
    let single_row: Option<String> = conn.exec_first(&ps, (1,)).unwrap();
    // A missing row must fail the test rather than silently skip the checks below.
    let val = single_row.expect("expected a row with id = 1");
    assert_eq!(metadata_from_prep.len(), 1);
    assert_eq!(metadata_from_prep[0].0, "val");
    assert_eq!(metadata_from_prep[0].1, ColumnType::MYSQL_TYPE_VAR_STRING);

    assert_eq!(val, "AAA".to_string());

    // 3. `SELECT ?`: the result column type follows the bound value on each execution.
    ps = conn.prep("SELECT ?").unwrap();

    columns_from_prep = ps.columns();
    metadata_from_prep = columns_from_prep
        .iter()
        .map(|column| (column.name_str().to_string(), column.column_type()))
        .collect();

    query_result = conn.exec_iter(&ps, (12,)).unwrap();
    columns_from_exec1 = query_result.columns();
    metadata_from_exec1 = columns_from_exec1
        .as_ref()
        .iter()
        .map(|column| (column.name_str().to_string(), column.column_type()))
        .collect();
    let fetched_rows: Vec<i32> = query_result
        .map(|row_result| crate::from_row(row_result.unwrap()))
        .collect();

    query_result = conn.exec_iter(&ps, (42,)).unwrap();
    let columns_from_exec2 = query_result.columns();
    let metadata_from_exec2 = columns_from_exec2
        .as_ref()
        .iter()
        .map(|column| (column.name_str().to_string(), column.column_type()))
        .collect::<Vec<_>>();
    let fetched_rows2: Vec<i32> = query_result
        .map(|row_result| crate::from_row(row_result.unwrap()))
        .collect();

    query_result = conn.exec_iter(&ps, ("foo",)).unwrap();
    let columns_from_exec3 = query_result.columns();
    let metadata_from_exec3 = columns_from_exec3
        .as_ref()
        .iter()
        .map(|column| (column.name_str().to_string(), column.column_type()))
        .collect::<Vec<_>>();
    let fetched_rows3: Vec<String> = query_result
        .map(|row_result| crate::from_row(row_result.unwrap()))
        .collect();

    assert_eq!(metadata_from_exec1.len(), 1);
    assert_eq!(metadata_from_exec2.len(), 1);
    assert_eq!(metadata_from_exec3.len(), 1);
    assert_eq!(metadata_from_prep.len(), 1);
    assert_eq!(metadata_from_prep[0].0, "?");
    assert!(
        metadata_from_prep[0].1 == ColumnType::MYSQL_TYPE_NULL
            || metadata_from_prep[0].1 == ColumnType::MYSQL_TYPE_VAR_STRING,
        "Expected MYSQL_TYPE_NULL(MariaDB) or MYSQL_TYPE_VAR_STRING(MySQL), got {:?}",
        metadata_from_prep[0].1
    );
    assert_eq!(metadata_from_exec1[0].0, "?");
    assert_eq!(metadata_from_exec1[0].1, ColumnType::MYSQL_TYPE_LONGLONG);
    assert_eq!(metadata_from_exec2[0].0, "?");
    assert_eq!(metadata_from_exec2[0].1, ColumnType::MYSQL_TYPE_LONGLONG);
    assert_eq!(metadata_from_exec3[0].0, "?");
    assert_eq!(metadata_from_exec3[0].1, ColumnType::MYSQL_TYPE_VAR_STRING);

    assert_eq!(fetched_rows[0], 12);
    assert_eq!(fetched_rows2[0], 42);
    assert_eq!(fetched_rows3[0], "foo".to_owned());
}
2755
#[test]
#[cfg(feature = "binlog")]
// Streams the binary log three times and checks that every received event has a parsable
// event type and body, and that every row of a RowsEvent decodes against the TableMapEvent
// seen earlier for the same table id:
//   1. from a file name + position (`BinlogRequest::new(12)`),
//   2. the same with `with_use_gtid(true)` (`BinlogRequest::new(13)`), skipped on MariaDB,
//   3. from a file name + position with BINLOG_DUMP_NON_BLOCK (`BinlogRequest::new(14)`).
// NOTE(review): the numeric argument of `BinlogRequest::new` is presumably the replica
// server id; confirm against mysql_common.
fn should_read_binlog() -> crate::Result<()> {
    use std::{
        collections::HashMap, sync::mpsc::sync_channel, thread::spawn, time::Duration,
    };

    // Generates binlog traffic: creates a table, inserts 100 rows one by one, then drops it.
    fn gen_dummy_data() -> crate::Result<()> {
        let mut conn = Conn::new(get_opts())?;

        "CREATE TABLE IF NOT EXISTS customers (customer_id int not null)".run(&mut conn)?;

        for i in 0_u8..100 {
            "INSERT INTO customers(customer_id) VALUES (?)"
                .with((i,))
                .run(&mut conn)?;
        }

        "DROP TABLE customers".run(&mut conn)?;

        Ok(())
    }

    // Opens a connection and returns it together with the binlog file name and position
    // taken from the first row of `SHOW BINARY LOGS`. Dummy data is generated on a separate
    // connection *after* that position is captured, so the stream has events to read.
    fn get_conn() -> crate::Result<(Conn, Vec<u8>, u64)> {
        let mut conn = Conn::new(get_opts())?;

        // Only enforced when the variable can be read; if the query fails (e.g. the server
        // has no such variable) the GTID check is skipped.
        if let Ok(Some(gtid_mode)) =
            "SELECT @@GLOBAL.GTID_MODE".first::<String, _>(&mut conn)
        {
            if !gtid_mode.starts_with("ON") {
                panic!(
                    "GTID_MODE is disabled \
                        (enable using --gtid_mode=ON --enforce_gtid_consistency=ON)"
                );
            }
        }

        let row: crate::Row = "SHOW BINARY LOGS".first(&mut conn)?.unwrap();
        let filename = row.get(0).unwrap();
        let position = row.get(1).unwrap();

        gen_dummy_data().unwrap();
        Ok((conn, filename, position))
    }

    // Pass 1: file name + position.
    let (conn, filename, pos) = get_conn().unwrap();
    let is_mariadb = conn.0.mariadb_server_version.is_some();

    let binlog_stream = conn
        .get_binlog_stream(BinlogRequest::new(12).with_filename(filename).with_pos(pos))
        .unwrap();

    // The stream is drained on a helper thread and handed over a zero-capacity
    // `sync_channel`; reading stops once no event arrives for 1 second, presumably because a
    // blocking stream does not end by itself. The helper thread is never joined; once `rx` is
    // dropped its `send` (if it gets that far) fails and that thread panics unobserved.
    let mut events_num = 0;
    let (tx, rx) = sync_channel(0);
    spawn(move || {
        for event in binlog_stream {
            tx.send(event).unwrap();
        }
    });
    // Table-map events keyed by table id; a RowsEvent for an unseen table id panics on the
    // `tmes[...]` index below.
    let mut tmes = HashMap::new();
    while let Ok(event) = rx.recv_timeout(Duration::from_secs(1)) {
        let event = event.unwrap();
        events_num += 1;

        // The event type must be recognised.
        event.header().event_type().unwrap();

        match event.read_data()?.unwrap() {
            EventData::TableMapEvent(tme) => {
                tmes.insert(tme.table_id(), tme.into_owned());
            }
            EventData::RowsEvent(re) => {
                for row in re.rows(&tmes[&re.table_id()]) {
                    row.unwrap();
                }
            }
            _ => (),
        }
    }
    assert!(events_num > 0);

    // Pass 2: same as pass 1 but with `with_use_gtid(true)`; not run against MariaDB.
    if !is_mariadb {
        let (conn, filename, pos) = get_conn().unwrap();

        let binlog_stream = conn
            .get_binlog_stream(
                BinlogRequest::new(13)
                    .with_use_gtid(true)
                    .with_filename(filename)
                    .with_pos(pos),
            )
            .unwrap();

        let mut events_num = 0;
        let (tx, rx) = sync_channel(0);
        spawn(move || {
            for event in binlog_stream {
                tx.send(event).unwrap();
            }
        });
        let mut tmes = HashMap::new();
        while let Ok(event) = rx.recv_timeout(Duration::from_secs(1)) {
            let event = event.unwrap();
            events_num += 1;

            event.header().event_type().unwrap();

            match event.read_data()?.unwrap() {
                EventData::TableMapEvent(tme) => {
                    tmes.insert(tme.table_id(), tme.into_owned());
                }
                EventData::RowsEvent(re) => {
                    for row in re.rows(&tmes[&re.table_id()]) {
                        row.unwrap();
                    }
                }
                _ => (),
            }
        }
        assert!(events_num > 0);
    }

    // Pass 3: BINLOG_DUMP_NON_BLOCK. The stream is iterated directly on this thread (no
    // idle timeout), presumably because it terminates at the end of the log; only header
    // parsing and `read_data` success are checked here.
    let (conn, filename, pos) = get_conn().unwrap();

    let binlog_stream = conn
        .get_binlog_stream(
            BinlogRequest::new(14)
                .with_filename(filename)
                .with_pos(pos)
                .with_flags(crate::BinlogDumpFlags::BINLOG_DUMP_NON_BLOCK),
        )
        .unwrap();

    events_num = 0;
    for event in binlog_stream {
        let event = event.unwrap();
        events_num += 1;
        event.header().event_type().unwrap();
        event.read_data()?;
    }
    assert!(events_num > 0);

    Ok(())
}
2906 }
2907
#[cfg(feature = "nightly")]
mod bench {
    //! Micro-benchmarks for common query paths (nightly-only `#[bench]` harness).
    //!
    //! Every benchmark unwraps the driver result so a failing query aborts the run instead of
    //! silently timing an error path.

    use test;

    use crate::{params, prelude::*, test_misc::get_opts, Conn, Value::NULL};

    /// Trivial text-protocol statement with no result set.
    #[bench]
    fn simple_exec(bencher: &mut test::Bencher) {
        let mut conn = Conn::new(get_opts()).unwrap();
        bencher.iter(|| {
            conn.query_drop("DO 1").unwrap();
        })
    }

    /// Executes an already-prepared statement with no parameters.
    #[bench]
    fn prepared_exec(bencher: &mut test::Bencher) {
        let mut conn = Conn::new(get_opts()).unwrap();
        let stmt = conn.prep("DO 1").unwrap();
        bencher.iter(|| {
            conn.exec_drop(&stmt, ()).unwrap();
        })
    }

    /// Prepares and executes a statement on every iteration.
    #[bench]
    fn prepare_and_exec(bencher: &mut test::Bencher) {
        let mut conn = Conn::new(get_opts()).unwrap();
        bencher.iter(|| {
            let stmt = conn.prep("SELECT ?").unwrap();
            conn.exec_drop(&stmt, (0,)).unwrap();
        })
    }

    /// Single-row text-protocol SELECT.
    #[bench]
    fn simple_query_row(bencher: &mut test::Bencher) {
        let mut conn = Conn::new(get_opts()).unwrap();
        bencher.iter(|| {
            conn.query_drop("SELECT 1").unwrap();
        })
    }

    /// Single-row prepared SELECT without parameters.
    #[bench]
    fn simple_prepared_query_row(bencher: &mut test::Bencher) {
        let mut conn = Conn::new(get_opts()).unwrap();
        let stmt = conn.prep("SELECT 1").unwrap();
        bencher.iter(|| {
            conn.exec_drop(&stmt, ()).unwrap();
        })
    }

    /// Single-row prepared SELECT with one positional parameter.
    #[bench]
    fn simple_prepared_query_row_with_param(bencher: &mut test::Bencher) {
        let mut conn = Conn::new(get_opts()).unwrap();
        let stmt = conn.prep("SELECT ?").unwrap();
        bencher.iter(|| {
            conn.exec_drop(&stmt, (0,)).unwrap();
        })
    }

    /// Single-row prepared SELECT with one named parameter.
    #[bench]
    fn simple_prepared_query_row_with_named_param(bencher: &mut test::Bencher) {
        let mut conn = Conn::new(get_opts()).unwrap();
        let stmt = conn.prep("SELECT :a").unwrap();
        bencher.iter(|| {
            conn.exec_drop(&stmt, params! {"a" => 0}).unwrap();
        })
    }

    /// Single-row prepared SELECT with five positional parameters of mixed types.
    #[bench]
    fn simple_prepared_query_row_with_5_params(bencher: &mut test::Bencher) {
        let mut conn = Conn::new(get_opts()).unwrap();
        let stmt = conn.prep("SELECT ?, ?, ?, ?, ?").unwrap();
        let params = (42i8, b"123456".to_vec(), 1.618f64, NULL, 1i8);
        bencher.iter(|| {
            conn.exec_drop(&stmt, &params).unwrap();
        })
    }

    /// Single-row prepared SELECT with five named parameters of mixed types.
    #[bench]
    fn simple_prepared_query_row_with_5_named_params(bencher: &mut test::Bencher) {
        let mut conn = Conn::new(get_opts()).unwrap();
        let stmt = conn
            .prep("SELECT :one, :two, :three, :four, :five")
            .unwrap();
        bencher.iter(|| {
            conn.exec_drop(
                &stmt,
                params! {
                    "one" => 42i8,
                    "two" => b"123456",
                    "three" => 1.618f64,
                    "four" => NULL,
                    "five" => 1i8,
                },
            )
            .unwrap();
        })
    }

    /// Text-protocol SELECT returning a 10 000-character string.
    #[bench]
    fn select_large_string(bencher: &mut test::Bencher) {
        let mut conn = Conn::new(get_opts()).unwrap();
        bencher.iter(|| {
            conn.query_drop("SELECT REPEAT('A', 10000)").unwrap();
        })
    }

    /// Prepared SELECT returning a 10 000-character string.
    #[bench]
    fn select_prepared_large_string(bencher: &mut test::Bencher) {
        let mut conn = Conn::new(get_opts()).unwrap();
        let stmt = conn.prep("SELECT REPEAT('A', 10000)").unwrap();
        bencher.iter(|| {
            conn.exec_drop(&stmt, ()).unwrap();
        })
    }

    /// Prepared SELECT over a 512-row single-column temporary table.
    #[bench]
    fn many_small_rows(bencher: &mut test::Bencher) {
        let mut conn = Conn::new(get_opts()).unwrap();
        conn.query_drop("CREATE TEMPORARY TABLE mysql.x (id INT)")
            .unwrap();
        for _ in 0..512 {
            conn.query_drop("INSERT INTO mysql.x VALUES (256)").unwrap();
        }
        let stmt = conn.prep("SELECT * FROM mysql.x").unwrap();
        bencher.iter(|| {
            conn.exec_drop(&stmt, ()).unwrap();
        });
    }
}
3036}