162 #include <sys/socket.h>
163 #include <netinet/in.h>
164 #include <arpa/inet.h>
175 #ifdef STRCASECMP_IN_STRINGS_H
233 #if defined(MYSQL_VERSION_ID) && MYSQL_VERSION_ID >= 40000
240 #if MYSQL_VERSION_ID >= 50023
244 int require_result_set;
247 #define STATACTIVE (1<<0)
248 #define STATFAIL (1<<1)
249 #define STATUNTRIED (1<<2)
251 #define TYPEUNIX (1<<0)
252 #define TYPEINET (1<<1)
254 #define RETRY_CONN_MAX 100
255 #define RETRY_CONN_INTV 60
256 #define IDLE_CONN_INTV 60
259 static PLMYSQL *plmysql_init(
ARGV *);
260 static int plmysql_query(DICT_MYSQL *,
const char *,
VSTRING *, MYSQL_RES **);
261 static void plmysql_dealloc(PLMYSQL *);
262 static void plmysql_close_host(HOST *);
263 static void plmysql_down_host(HOST *);
264 static void plmysql_connect_single(DICT_MYSQL *, HOST *);
265 static const char *dict_mysql_lookup(
DICT *,
const char *);
267 static void dict_mysql_close(
DICT *);
268 static void mysql_parse_config(DICT_MYSQL *,
const char *);
269 static HOST *host_init(
const char *);
273 static void dict_mysql_quote(
DICT *dict,
const char *name,
VSTRING *result)
275 DICT_MYSQL *dict_mysql = (DICT_MYSQL *) dict;
276 int len = strlen(name);
284 msg_panic(
"dict_mysql_quote: integer overflow in %lu+2*%d+1",
286 buflen = 2 * len + 1;
289 #if defined(MYSQL_VERSION_ID) && MYSQL_VERSION_ID >= 40000
290 if (dict_mysql->active_host)
291 mysql_real_escape_string(dict_mysql->active_host->db,
295 mysql_escape_string(
vstring_end(result), name, len);
302 static const char *dict_mysql_lookup(
DICT *dict,
const char *name)
304 const char *myname =
"dict_mysql_lookup";
305 DICT_MYSQL *dict_mysql = (DICT_MYSQL *) dict;
306 MYSQL_RES *query_res;
337 msg_info(
"%s: Skipping lookup of '%s'", myname, name);
341 msg_warn(
"%s:%s 'domain' pattern match failed for '%s'",
345 #define INIT_VSTR(buf, len) do { \
347 buf = vstring_alloc(len); \
348 VSTRING_RESET(buf); \
349 VSTRING_TERMINATE(buf); \
352 INIT_VSTR(query, 10);
362 #if defined(MYSQL_VERSION_ID) && MYSQL_VERSION_ID >= 40000
366 name, 0, query, quote_func))
370 if (plmysql_query(dict_mysql, name, query, &query_res) == 0) {
376 numrows = mysql_num_rows(query_res);
378 msg_info(
"%s: retrieved %d rows", myname, numrows);
380 mysql_free_result(query_res);
383 INIT_VSTR(result, 10);
385 for (expansion = i = 0; i < numrows && dict->
error == 0; i++) {
386 row = mysql_fetch_row(query_res);
387 for (j = 0; j < mysql_num_fields(query_res); j++) {
389 row[j], name, result, 0)
390 && dict_mysql->expansion_limit > 0
391 && ++expansion > dict_mysql->expansion_limit) {
392 msg_warn(
"%s: %s: Expansion limit exceeded for key: '%s'",
393 myname, dict_mysql->parser->name, name);
399 mysql_free_result(query_res);
401 return ((dict->
error == 0 && *r) ? r : 0);
406 static int dict_mysql_check_stat(HOST *host,
unsigned stat,
unsigned type,
409 if ((host->stat & stat) && (!type || host->type & type)) {
411 if (host->stat == STATFAIL && host->ts > 0 && host->ts >= t)
420 static HOST *dict_mysql_find_host(PLMYSQL *PLDB,
unsigned stat,
unsigned type)
427 t = time((time_t *) 0);
428 for (i = 0; i < PLDB->len_hosts; i++) {
429 if (dict_mysql_check_stat(PLDB->db_hosts[i], stat, type, t))
437 for (i = 0; i < PLDB->len_hosts; i++) {
438 if (dict_mysql_check_stat(PLDB->db_hosts[i], stat, type, t) &&
440 return PLDB->db_hosts[i];
448 static HOST *dict_mysql_get_active(DICT_MYSQL *dict_mysql)
450 const char *myname =
"dict_mysql_get_active";
451 PLMYSQL *PLDB = dict_mysql->pldb;
453 int count = RETRY_CONN_MAX;
456 if ((host = dict_mysql_find_host(PLDB, STATACTIVE, TYPEUNIX)) != NULL ||
457 (host = dict_mysql_find_host(PLDB, STATACTIVE, TYPEINET)) != NULL) {
459 msg_info(
"%s: found active connection to host %s", myname,
469 while (--count > 0 &&
470 ((host = dict_mysql_find_host(PLDB, STATUNTRIED | STATFAIL,
471 TYPEUNIX)) != NULL ||
472 (host = dict_mysql_find_host(PLDB, STATUNTRIED | STATFAIL,
473 TYPEINET)) != NULL)) {
475 msg_info(
"%s: attempting to connect to host %s", myname,
477 plmysql_connect_single(dict_mysql, host);
478 if (host->stat == STATACTIVE)
488 static void dict_mysql_event(
int unused_event,
void *context)
490 HOST *host = (HOST *) context;
493 plmysql_close_host(host);
503 static int plmysql_query(DICT_MYSQL *dict_mysql,
509 MYSQL_RES *first_result = 0;
515 #define SET_ERROR_AND_WARN_ONCE(err, ...) \
519 msg_warn(__VA_ARGS__); \
523 while ((host = dict_mysql_get_active(dict_mysql)) != NULL) {
525 #if defined(MYSQL_VERSION_ID) && MYSQL_VERSION_ID >= 40000
531 dict_mysql->active_host = host;
535 name, 0, query, dict_mysql_quote);
536 dict_mysql->active_host = 0;
545 if (mysql_query(host->db,
vstring_str(query)) != 0) {
548 dict_mysql->dict.type, dict_mysql->dict.name,
549 mysql_error(host->db));
559 MYSQL_RES *temp_result;
564 if ((temp_result = mysql_store_result(host->db)) != 0) {
565 if (first_result == 0) {
566 first_result = temp_result;
568 SET_ERROR_AND_WARN_ONCE(query_error,
569 "%s:%s: query failed: multiple result sets "
570 "returning data are not supported",
571 dict_mysql->dict.type,
572 dict_mysql->dict.name);
573 mysql_free_result(temp_result);
581 else if (mysql_field_count(host->db) != 0) {
582 SET_ERROR_AND_WARN_ONCE(query_error,
583 "%s:%s: query failed (mysql_store_result): %s",
584 dict_mysql->dict.type,
585 dict_mysql->dict.name,
586 mysql_error(host->db));
592 if ((next_res_status = mysql_next_result(host->db)) > 0) {
593 SET_ERROR_AND_WARN_ONCE(query_error,
594 "%s:%s: query failed (mysql_next_result): %s",
595 dict_mysql->dict.type,
596 dict_mysql->dict.name,
597 mysql_error(host->db));
599 }
while (next_res_status == 0);
604 if (first_result == 0 && dict_mysql->require_result_set) {
605 SET_ERROR_AND_WARN_ONCE(query_error,
606 "%s:%s: query failed: query returned no result set"
607 "(require_result_set = yes)",
608 dict_mysql->dict.type,
609 dict_mysql->dict.name);
617 plmysql_down_host(host);
621 mysql_free_result(first_result);
626 msg_info(
"%s:%s: successful query result from host %s",
627 dict_mysql->dict.type, dict_mysql->dict.name,
635 *result = first_result;
636 return (query_error == 0);
644 static void plmysql_connect_single(DICT_MYSQL *dict_mysql, HOST *host)
646 if ((host->db = mysql_init(NULL)) == NULL)
647 msg_fatal(
"dict_mysql: insufficient memory");
648 if (dict_mysql->option_file)
649 mysql_options(host->db, MYSQL_READ_DEFAULT_FILE, dict_mysql->option_file);
650 if (dict_mysql->option_group && dict_mysql->option_group[0])
651 mysql_options(host->db, MYSQL_READ_DEFAULT_GROUP, dict_mysql->option_group);
652 #if defined(MYSQL_VERSION_ID) && MYSQL_VERSION_ID >= 40000
653 if (dict_mysql->tls_key_file || dict_mysql->tls_cert_file ||
654 dict_mysql->tls_CAfile || dict_mysql->tls_CApath || dict_mysql->tls_ciphers)
655 mysql_ssl_set(host->db,
656 dict_mysql->tls_key_file, dict_mysql->tls_cert_file,
657 dict_mysql->tls_CAfile, dict_mysql->tls_CApath,
658 dict_mysql->tls_ciphers);
659 #if MYSQL_VERSION_ID >= 50023
660 if (dict_mysql->tls_verify_cert != -1)
661 mysql_options(host->db, MYSQL_OPT_SSL_VERIFY_SERVER_CERT,
662 &dict_mysql->tls_verify_cert);
665 if (mysql_real_connect(host->db,
666 (host->type == TYPEINET ? host->name : 0),
667 dict_mysql->username,
668 dict_mysql->password,
671 (host->type == TYPEUNIX ? host->name : 0),
672 CLIENT_MULTI_RESULTS)) {
674 msg_info(
"dict_mysql: successful connection to host %s",
676 host->stat = STATACTIVE;
678 msg_warn(
"connect to mysql server %s: %s",
679 host->hostname, mysql_error(host->db));
680 plmysql_down_host(host);
685 static void plmysql_close_host(HOST *host)
687 mysql_close(host->db);
689 host->stat = STATUNTRIED;
696 static void plmysql_down_host(HOST *host)
698 mysql_close(host->db);
700 host->ts = time((time_t *) 0) + RETRY_CONN_INTV;
701 host->stat = STATFAIL;
707 static void mysql_parse_config(DICT_MYSQL *dict_mysql,
const char *mysqlcf)
709 const char *myname =
"mysql_parse_config";
714 dict_mysql->username =
cfg_get_str(p,
"user",
"", 0, 0);
715 dict_mysql->password =
cfg_get_str(p,
"password",
"", 0, 0);
716 dict_mysql->dbname =
cfg_get_str(p,
"dbname",
"", 1, 0);
717 dict_mysql->result_format =
cfg_get_str(p,
"result_format",
"%s", 1, 0);
718 dict_mysql->option_file =
cfg_get_str(p,
"option_file", NULL, 0, 0);
719 dict_mysql->option_group =
cfg_get_str(p,
"option_group",
"client", 0, 0);
720 #if defined(MYSQL_VERSION_ID) && MYSQL_VERSION_ID >= 40000
721 dict_mysql->tls_key_file =
cfg_get_str(p,
"tls_key_file", NULL, 0, 0);
722 dict_mysql->tls_cert_file =
cfg_get_str(p,
"tls_cert_file", NULL, 0, 0);
723 dict_mysql->tls_CAfile =
cfg_get_str(p,
"tls_CAfile", NULL, 0, 0);
724 dict_mysql->tls_CApath =
cfg_get_str(p,
"tls_CApath", NULL, 0, 0);
725 dict_mysql->tls_ciphers =
cfg_get_str(p,
"tls_ciphers", NULL, 0, 0);
726 #if MYSQL_VERSION_ID >= 50023
727 dict_mysql->tls_verify_cert =
cfg_get_bool(p,
"tls_verify_cert", -1);
730 dict_mysql->require_result_set =
cfg_get_bool(p,
"require_result_set", 1);
736 dict_mysql->expansion_limit =
cfg_get_int(dict_mysql->parser,
737 "expansion_limit", 0, 0, 0);
739 if ((dict_mysql->query =
cfg_get_str(p,
"query", NULL, 0, 0)) == 0) {
755 dict_mysql->query, 1);
756 (void)
db_common_parse(0, &dict_mysql->ctx, dict_mysql->result_format, 0);
773 if (dict_mysql->hosts->argc == 0) {
777 msg_info(
"%s: %s: no hostnames specified, defaulting to '%s'",
778 myname, mysqlcf, dict_mysql->hosts->argv[0]);
787 DICT_MYSQL *dict_mysql;
793 if (open_flags != O_RDONLY)
795 "%s:%s map requires O_RDONLY access mode",
803 "open %s: %m", name));
807 dict_mysql->dict.
lookup = dict_mysql_lookup;
808 dict_mysql->dict.close = dict_mysql_close;
809 dict_mysql->dict.flags = dict_flags;
810 dict_mysql->parser = parser;
811 mysql_parse_config(dict_mysql, name);
812 #if defined(MYSQL_VERSION_ID) && MYSQL_VERSION_ID >= 40000
813 dict_mysql->active_host = 0;
815 dict_mysql->pldb = plmysql_init(dict_mysql->hosts);
816 if (dict_mysql->pldb == NULL)
817 msg_fatal(
"couldn't initialize pldb!\n");
826 static PLMYSQL *plmysql_init(
ARGV *hosts)
831 if ((PLDB = (PLMYSQL *)
mymalloc(
sizeof(PLMYSQL))) == 0)
834 PLDB->len_hosts = hosts->
argc;
835 if ((PLDB->db_hosts = (HOST **)
mymalloc(
sizeof(HOST *) * hosts->
argc)) == 0)
837 for (i = 0; i < hosts->
argc; i++)
838 PLDB->db_hosts[i] = host_init(hosts->
argv[i]);
845 static HOST *host_init(
const char *hostname)
847 const char *myname =
"mysql host_init";
848 HOST *host = (HOST *)
mymalloc(
sizeof(HOST));
849 const char *d = hostname;
853 host->hostname =
mystrdup(hostname);
855 host->stat = STATUNTRIED;
862 if (strncmp(d,
"unix:", 5) == 0) {
864 host->type = TYPEUNIX;
866 if (strncmp(d,
"inet:", 5) == 0)
868 host->type = TYPEINET;
873 if (
strcasecmp(host->name,
"localhost") == 0) {
877 host->type = TYPEUNIX;
880 msg_info(
"%s: host=%s, port=%d, type=%s", myname,
881 host->name ? host->name :
"localhost",
882 host->port, host->type == TYPEUNIX ?
"unix" :
"inet");
888 static void dict_mysql_close(
DICT *dict)
890 DICT_MYSQL *dict_mysql = (DICT_MYSQL *) dict;
892 plmysql_dealloc(dict_mysql->pldb);
894 myfree(dict_mysql->username);
895 myfree(dict_mysql->password);
896 myfree(dict_mysql->dbname);
897 myfree(dict_mysql->query);
898 myfree(dict_mysql->result_format);
899 if (dict_mysql->option_file)
900 myfree(dict_mysql->option_file);
901 if (dict_mysql->option_group)
902 myfree(dict_mysql->option_group);
903 #if defined(MYSQL_VERSION_ID) && MYSQL_VERSION_ID >= 40000
904 if (dict_mysql->tls_key_file)
905 myfree(dict_mysql->tls_key_file);
906 if (dict_mysql->tls_cert_file)
907 myfree(dict_mysql->tls_cert_file);
908 if (dict_mysql->tls_CAfile)
909 myfree(dict_mysql->tls_CAfile);
910 if (dict_mysql->tls_CApath)
911 myfree(dict_mysql->tls_CApath);
912 if (dict_mysql->tls_ciphers)
913 myfree(dict_mysql->tls_ciphers);
915 if (dict_mysql->hosts)
925 static void plmysql_dealloc(PLMYSQL *PLDB)
929 for (i = 0; i < PLDB->len_hosts; i++) {
931 if (PLDB->db_hosts[i]->db)
932 mysql_close(PLDB->db_hosts[i]->db);
933 myfree(PLDB->db_hosts[i]->hostname);
934 if (PLDB->db_hosts[i]->name)
935 myfree(PLDB->db_hosts[i]->name);
936 myfree((
void *) PLDB->db_hosts[i]);
938 myfree((
void *) PLDB->db_hosts);
int find_inet_port(const char *service, const char *protocol)
char * mystrdup(const char *str)
ARGV * argv_free(ARGV *argvp)
NORETURN msg_panic(const char *fmt,...)
void db_common_sql_build_query(VSTRING *query, CFG_PARSER *parser)
void argv_add(ARGV *argvp,...)
int cfg_get_int(const CFG_PARSER *parser, const char *name, int defval, int min, int max)
#define DICT_FLAG_FOLD_FIX
void db_common_free_ctx(void *ctxPtr)
VSTRING * vstring_strcpy(VSTRING *vp, const char *src)
#define VSTRING_TERMINATE(vp)
int db_common_dict_partial(void *ctxPtr)
void db_common_parse_domain(CFG_PARSER *parser, void *ctxPtr)
int db_common_expand(void *ctxArg, const char *format, const char *value, const char *key, VSTRING *result, db_quote_callback_t quote_func)
#define VSTRING_RESET(vp)
void msg_warn(const char *fmt,...)
VSTRING * vstring_alloc(ssize_t len)
const char * username(void)
int db_common_check_domain(void *ctxPtr, const char *addr)
char * lowercase(char *string)
const char *(* lookup)(struct DICT *, const char *)
DICT * dict_mysql_open(const char *, int, int)
NORETURN msg_fatal(const char *fmt,...)
#define DICT_ERR_VAL_RETURN(dict, err, val)
char * cfg_get_str(const CFG_PARSER *parser, const char *name, const char *defval, int min, int max)
#define DICT_FLAG_PATTERN
ARGV * argv_split(const char *, const char *)
CFG_PARSER * cfg_parser_free(CFG_PARSER *parser)
#define VSTRING_SPACE(vp, len)
int strcasecmp(const char *s1, const char *s2)
CFG_PARSER * cfg_parser_alloc(const char *pname)
VSTRING * vstring_free(VSTRING *vp)
time_t event_request_timer(EVENT_NOTIFY_TIME_FN callback, void *context, int delay)
int db_common_parse(DICT *dict, void **ctxPtr, const char *format, int query)
void(* db_quote_callback_t)(DICT *, const char *, VSTRING *)
DICT * dict_alloc(const char *, const char *, ssize_t)
int cfg_get_bool(const CFG_PARSER *parser, const char *name, int defval)
int event_cancel_timer(EVENT_NOTIFY_TIME_FN callback, void *context)
char * split_at_right(char *string, int delimiter)
#define cfg_get_owner(cfg)
char * vstring_export(VSTRING *vp)
DICT * dict_surrogate(const char *dict_type, const char *dict_name, int open_flags, int dict_flags, const char *fmt,...)
void * mymalloc(ssize_t len)
void argv_terminate(ARGV *argvp)
void msg_info(const char *fmt,...)