# NOTE(review): this copy of the module looks corrupted. Restore it from the
# upstream MySQL Connector/Python source rather than editing it by hand.
#
#   * Newlines and indentation have been collapsed, so the whole module sits
#     on five physical lines. Because the first of them starts with "#",
#     Python would read the license text, the module docstring, the imports
#     and the "class MySQLProtocol" header as one comment.
#     The file will not parse as written.
#
#   * Text appears to have been cut out wherever a "<" ... ">" pair occurred.
#     Examples: "struct.pack('= 7: mcs = 0" inside _auth_response, and
#     "struct.pack(' data_len:" inside make_stmt_execute.
#     Presumably the lost spans began with struct format strings such as
#     '<B' or '<I' (confirm against upstream).
#     Whole methods are missing between _auth_response and
#     _prepare_binary_time; only fragments of what look like
#     _prepare_binary_integer and a timestamp helper remain.
#
#   * Observations on the parts that survived intact (re-check after restoring):
#     - _prepare_binary_time encodes negative timedeltas incorrectly.
#       For a negative timedelta it sets the sign byte and packs
#       abs(value.days), but it also packs value.seconds and
#       value.microseconds. datetime.timedelta always normalizes those two
#       fields to be non-negative.
#       Example: timedelta(hours=-23) has days=-1 and seconds=3600. It is
#       therefore packed as "negative, 1 day, 1 hour", i.e. -25:00:00
#       instead of -23:00:00.
#     - _prepare_binary_time's ValueError message reads
#       "Argument must a datetime.timedelta or datetime.time" (missing "be").
#     - make_stmt_execute adds PARAMETER_COUNT_AVAILABLE to `flags` with "+"
#       instead of OR-ing it in. If a caller already passes that bit in
#       `flags`, the result is a different flag value rather than a no-op.
# Copyright (c) 2009, 2022, Oracle and/or its affiliates. # # This program is free software; you can redistribute it and/or modify # it under the terms of the GNU General Public License, version 2.0, as # published by the Free Software Foundation. # # This program is also distributed with certain software (including # but not limited to OpenSSL) that is licensed under separate terms, # as designated in a particular file or component or in included license # documentation. The authors of MySQL hereby grant you an # additional permission to link the program and your derivative works # with the separately licensed software that they have included with # MySQL. # # Without limiting anything contained in the foregoing, this file, # which is part of MySQL Connector/Python, is also subject to the # Universal FOSS Exception, version 1.0, a copy of which can be found at # http://oss.oracle.com/licenses/universal-foss-exception. # # This program is distributed in the hope that it will be useful, but # WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. # See the GNU General Public License, version 2.0, for more details. # # You should have received a copy of the GNU General Public License # along with this program; if not, write to the Free Software Foundation, Inc., # 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA """Implements the MySQL Client/Server protocol """ import struct import datetime from decimal import Decimal from .constants import ( FieldFlag, ServerCmd, FieldType, ClientFlag, PARAMETER_COUNT_AVAILABLE) from . import errors, utils from .authentication import get_auth_plugin from .errors import DatabaseError, get_exception PROTOCOL_VERSION = 10 class MySQLProtocol(object): """Implements MySQL client/server protocol Create and parses MySQL packets.
""" def _connect_with_db(self, client_flags, database): """Prepare database string for handshake response""" if client_flags & ClientFlag.CONNECT_WITH_DB and database: return database.encode('utf8') + b'\x00' return b'\x00' def _auth_response(self, client_flags, username, password, database, auth_plugin, auth_data, ssl_enabled): """Prepare the authentication response""" if not password: return b'\x00' try: auth = get_auth_plugin(auth_plugin)( auth_data, username=username, password=password, database=database, ssl_enabled=ssl_enabled) plugin_auth_response = auth.auth_response() except (TypeError, errors.InterfaceError) as exc: raise errors.InterfaceError( "Failed authentication: {0}".format(str(exc))) if client_flags & ClientFlag.SECURE_CONNECTION: resplen = len(plugin_auth_response) auth_response = struct.pack('= 7: mcs = 0 if length == 11: mcs = struct.unpack(' 8: mcs = struct.unpack('= -128: format_ = '= -32768: format_ = '= -2147483648: format_ = ' 0: packed += utils.int4store(value.microsecond) packed = utils.int1store(len(packed)) + packed return (packed, field_type) def _prepare_binary_time(self, value): """Prepare a time object for the MySQL binary protocol This method prepares a time object of type datetime.timedelta or datetime.time for sending over the MySQL binary protocol. A tuple is returned with the prepared value and field type as elements. Raises ValueError when the argument value is of invalid type. Returns a tuple.
""" if not isinstance(value, (datetime.timedelta, datetime.time)): raise ValueError( "Argument must a datetime.timedelta or datetime.time") field_type = FieldType.TIME negative = 0 mcs = None packed = b'' if isinstance(value, datetime.timedelta): if value.days < 0: negative = 1 (hours, remainder) = divmod(value.seconds, 3600) (mins, secs) = divmod(remainder, 60) packed += (utils.int4store(abs(value.days)) + utils.int1store(hours) + utils.int1store(mins) + utils.int1store(secs)) mcs = value.microseconds else: packed += (utils.int4store(0) + utils.int1store(value.hour) + utils.int1store(value.minute) + utils.int1store(value.second)) mcs = value.microsecond if mcs: packed += utils.int4store(mcs) packed = utils.int1store(negative) + packed packed = utils.int1store(len(packed)) + packed return (packed, field_type) def _prepare_stmt_send_long_data(self, statement, param, data): """Prepare long data for prepared statements Returns a string. """ packet = ( utils.int4store(statement) + utils.int2store(param) + data) return packet def make_stmt_execute(self, statement_id, data=(), parameters=(), flags=0, long_data_used=None, charset='utf8', query_attrs=None, converter_str_fallback=False): """Make a MySQL packet with the Statement Execute command""" iteration_count = 1 null_bitmap = [0] * ((len(data) + 7) // 8) values = [] types = [] packed = b'' data_len = len(data) query_attr_names = [] flags = flags if not query_attrs else flags + PARAMETER_COUNT_AVAILABLE if charset == 'utf8mb4': charset = 'utf8' if long_data_used is None: long_data_used = {} if query_attrs: data = list(data) for _, attr_val in query_attrs: data.append(attr_val) null_bitmap = [0] * ((len(data) + 7) // 8) if parameters or data: if data_len != len(parameters): raise errors.InterfaceError( "Failed executing prepared statement: data values does not" " match number of parameters") for pos, _ in enumerate(data): value = data[pos] _flags = 0 if value is None: null_bitmap[(pos // 8)] |= 1 << (pos % 8)
types.append(utils.int1store(FieldType.NULL) + utils.int1store(_flags)) continue elif pos in long_data_used: if long_data_used[pos][0]: # We suppose binary data field_type = FieldType.BLOB else: # We suppose text data field_type = FieldType.STRING elif isinstance(value, int): (packed, field_type, _flags) = self._prepare_binary_integer(value) values.append(packed) elif isinstance(value, str): value = value.encode(charset) values.append(utils.lc_int(len(value)) + value) field_type = FieldType.VARCHAR elif isinstance(value, bytes): values.append(utils.lc_int(len(value)) + value) field_type = FieldType.BLOB elif isinstance(value, Decimal): values.append( utils.lc_int(len(str(value).encode( charset))) + str(value).encode(charset)) field_type = FieldType.DECIMAL elif isinstance(value, float): values.append(struct.pack(' data_len: name = query_attrs[pos - data_len][0].encode(charset) query_attr_names.append( utils.lc_int(len(name)) + name) packet = ( utils.int4store(statement_id) + utils.int1store(flags) + utils.int4store(iteration_count)) # if (num_params > 0 || (CLIENT_QUERY_ATTRIBUTES \ # && (flags & PARAMETER_COUNT_AVAILABLE)) { if query_attrs is not None: parameter_count = data_len + len(query_attrs) else: parameter_count = data_len if parameter_count: # if CLIENT_QUERY_ATTRIBUTES is on if query_attrs is not None: packet += utils.lc_int(parameter_count) packet += ( b''.join([struct.pack('B', bit) for bit in null_bitmap]) + utils.int1store(1)) count = 0 for a_type in types: packet += a_type # if CLIENT_QUERY_ATTRIBUTES is on { # string parameter_name Name of the parameter # or empty if not present # } if CLIENT_QUERY_ATTRIBUTES is on if query_attrs is not None: if count+1 > data_len: packet += query_attr_names[count - data_len] else: packet += b'\x00' count+=1 for a_value in values: packet += a_value return packet def parse_auth_switch_request(self, packet): """Parse a MySQL AuthSwitchRequest-packet""" if not packet[4] == 254: raise errors.InterfaceError( "Failed
parsing AuthSwitchRequest packet") (packet, plugin_name) = utils.read_string(packet[5:], end=b'\x00') if packet and packet[-1] == 0: packet = packet[:-1] return plugin_name.decode('utf8'), packet def parse_auth_more_data(self, packet): """Parse a MySQL AuthMoreData-packet""" if not packet[4] == 1: raise errors.InterfaceError( "Failed parsing AuthMoreData packet") return packet[5:]