Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(dns): simplify the code of new dns utils #13398

Merged
merged 2 commits into from
Jul 25, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 37 additions & 23 deletions kong/dns/utils.lua
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ local type = type
local ipairs = ipairs
local tonumber = tonumber
local math_random = math.random
local table_new = require("table.new")
local table_clear = require("table.clear")
local table_insert = table.insert
local table_remove = table.remove
Expand All @@ -32,12 +33,9 @@ local LOCALHOST = {
local DEFAULT_HOSTS = { localhost = LOCALHOST, }


local _M = {}


-- checks the hostname type
-- @return "ipv4", "ipv6", or "domain"
function _M.hostname_type(name)
local function hostname_type(name)
local remainder, colons = name:gsub(":", "")
if colons > 1 then
return "ipv6"
Expand All @@ -55,8 +53,8 @@ end
-- IPv6 addresses are always returned in square brackets
-- @param name the string to check (this may contain a port number)
-- @return `name/ip` + `port (or nil)` + `type ("ipv4", "ipv6" or "domain")`
function _M.parse_hostname(name)
local t = _M.hostname_type(name)
local function parse_hostname(name)
local t = hostname_type(name)
if t == "ipv4" or t == "domain" then
local ip, port = name:match("^([^:]+)%:*(%d*)$")
return ip, tonumber(port), t
Expand All @@ -81,7 +79,7 @@ local function get_lines(path)
end


function _M.parse_hosts(path, enable_ipv6)
local function parse_hosts(path, enable_ipv6)
local lines, err = get_lines(path or DEFAULT_HOSTS_FILE)
if not lines then
log(NOTICE, "Invalid hosts file: ", err)
Expand All @@ -105,7 +103,7 @@ function _M.parse_hosts(path, enable_ipv6)

-- Check if the line contains an IP address followed by hostnames
if n >= 2 then
local ip, _, family = _M.parse_hostname(parts[1])
local ip, _, family = parse_hostname(parts[1])

if family ~= "domain" then -- ipv4/ipv6
for i = 2, n do
Expand All @@ -132,7 +130,7 @@ end


-- TODO: need to rewrite it instead of calling parseResolvConf from the old library
function _M.parse_resolv_conf(path, enable_ipv6)
local function parse_resolv_conf(path, enable_ipv6)
local resolv, err = utils.parseResolvConf(path or DEFAULT_RESOLV_CONF)
if not resolv then
return nil, err
Expand Down Expand Up @@ -161,14 +159,16 @@ function _M.parse_resolv_conf(path, enable_ipv6)

-- nameservers
if resolv.nameserver then
local n = 0
local nameservers = {}

for _, address in ipairs(resolv.nameserver) do
local ip, port, t = utils.parseHostname(address)
if t == "ipv4" or
(t == "ipv6" and not ip:find([[%]], nil, true) and enable_ipv6)
then
table_insert(nameservers, port and { ip, port } or ip)
n = n + 1
nameservers[n] = port and { ip, port } or ip
end
end

Expand All @@ -179,7 +179,7 @@ function _M.parse_resolv_conf(path, enable_ipv6)
end


function _M.is_fqdn(name, ndots)
local function is_fqdn(name, ndots)
if name:sub(-1) == "." then
return true
end
Expand All @@ -191,33 +191,35 @@ end


-- check if it matchs the SRV pattern: _<service>._<proto>.<name>
function _M.is_srv(name)
local function is_srv(name)
return name:match("^_[^._]+%._[^._]+%.[^.]+") ~= nil
end


-- construct names from resolv options: search, ndots and domain
function _M.search_names(name, resolv, hosts)
if not resolv.search or _M.is_fqdn(name, resolv.ndots) or
local function search_names(name, resolv, hosts)
local resolv_search = resolv.search

if not resolv_search or is_fqdn(name, resolv.ndots) or
(hosts and hosts[name])
then
return { name }
end

local names = {}
local count = #resolv_search
local names = table_new(count + 1, 0)

for _, suffix in ipairs(resolv.search) do
table_insert(names, name .. "." .. suffix)
for i = 1, count do
names[i] = name .. "." .. resolv_search[i]
end

table_insert(names, name) -- append the original name at last
names[count + 1] = name -- append the original name at last

return names
end


-- add square brackets around IPv6 addresses if a non-strict check detects them
function _M.ipv6_bracket(name)
local function ipv6_bracket(name)
if name:match("^[^[].*:") then -- not start with '[' and contains ':'
return "[" .. name .. "]"
end
Expand All @@ -228,13 +230,14 @@ end

-- util APIs to balance @answers

function _M.get_next_round_robin_answer(answers)
local function get_next_round_robin_answer(answers)
answers.last = (answers.last or 0) % #answers + 1

return answers[answers.last]
end


local get_next_weighted_round_robin_answer
do
-- based on the Nginx's SWRR algorithm and lua-resty-balancer
local function swrr_next(answers)
Expand Down Expand Up @@ -296,7 +299,7 @@ do
end


function _M.get_next_weighted_round_robin_answer(answers)
get_next_weighted_round_robin_answer = function(answers)
local l = answers.lowest_prio_records or filter_lowest_priority_answers(answers)

-- perform round robin selection on lowest priority answers @l
Expand All @@ -309,4 +312,15 @@ do
end


return _M
return {
hostname_type = hostname_type,
parse_hostname = parse_hostname,
parse_hosts = parse_hosts,
parse_resolv_conf = parse_resolv_conf,
is_fqdn = is_fqdn,
is_srv = is_srv,
search_names = search_names,
ipv6_bracket = ipv6_bracket,
get_next_round_robin_answer = get_next_round_robin_answer,
get_next_weighted_round_robin_answer = get_next_weighted_round_robin_answer,
}
Loading