柴少的官方网站 技术在学习中进步,水平在分享中升华

Nginx结合Lua实现二次验证(二)

好了紧跟上文,上一篇 http://www.51niux.com/?id=323  我们已经详细的了解了Nginx结合lua的一些简单用法,下面我们用一个现实中比较常见的例子来继续了解。

#我先描述一下场景啊,二次验证已经非常的司空见惯了啊,比如你登录阿里云腾讯云这些需要绑定MFA并且每次登录的时候都需要输入一个基于时间的6位数字,当然这种方式也是企业内部也是被广泛使用的,比如我们有些系统由于场景需要比如我们有很多分城市,需要公网访问一些个别的系统,你想除了账号密码认证外,还想结合每个人的身份分配一个唯一的秘钥字符串,以此生成一个6位验证数字,只有登录通过后才能1天内使用(当然也可以通过阿里云的sase等一些安全产品解决公网访问的问题)。

#比如你现在就是内网访问,但是你还是想多一层验证,比如每次登录堡垒机等跟权限相关的系统,除了账号密码外,还是想让用手机上的mfa验证一下(每个人都有内部账号,每个账号都会入职的时候分配一个唯一的字符串,用于结合当前时间产生一个变化只有30秒生命周期的的6位数字)。

#好了上面是一个实际场景的大概描述,下面就让我们一步步实现它吧

一、先用脚本产生一个固定字符串的随机验证

#vi totp_generator.lua

-- 导入bit32库,用于位操作,位操作就像搭积木,我们可以把数字拆成二进制的小块,然后移动或组合它们
-- Lua5.1中需要使用bit库,Lua5.2+可以直接使用bit32
local bit32 = require("bit")

-- Base32解码函数(阿里云MFA密钥通常为Base32编码)
-- Base32就像是一种密码本,用32个字母和数字表示二进制数据
local function base32_decode(base32_str)
    -- 移除所有空格并转为大写,就像整理信件一样,我们先去掉多余的空格,并且统一字母大小写
    base32_str = base32_str:gsub("%s", ""):upper()
    
    -- 检查输入长度,如果没有内容,就像空信封一样,我们无法处理
    if #base32_str == 0 then
        error("Base32字符串不能为空")
    end
    
    -- 检查是否有填充字符'=',Base32编码会用'='符号填充,就像拼图最后用空白块补齐一样
    local padding = base32_str:match("=+$")
    if padding then
        -- 移除字符串末尾的填充字符,从1开始,#base32_str:获取原字符串的长度,#padding:获取填充字符子串的长度
        -- 相减后得到的是不包含填充字符的字符串的结束位置,整体效果是:保留原字符串从开头到填充字符之前的部分
        base32_str = base32_str:sub(1, #base32_str - #padding)
    end
    
    -- 验证输入长度(必须是8的倍数),Base32的规则是每8个字符代表5个原始数据字节
    if #base32_str % 8 ~= 0 then
        error("Base32字符串长度必须是8的倍数")
    end
    
    -- Base32字符集定义,这是Base32编码的"密码本",每个字母和数字对应一个0-31的数字
    local base32_chars = "ABCDEFGHIJKLMNOPQRSTUVWXYZ234567"
    local result = ""       -- 存储解码后的结果
    local bits = 0          -- 累积的位数
    local value = 0         -- 累积的值
    
    -- 逐字符解码,一个字母一个字母地"翻译"Base32字符串
    for i = 1, #base32_str do
        local c = base32_str:sub(i, i)
        
        -- 验证字符是否有效,如果在密码本里找不到这个字符,就说明是无效的.
        if c == "" then
            error("Base32字符串包含空字符")
        end      
        -- 将字符映射到对应的值(0-31),就像查字典一样,找到字符在密码本中的位置
        local digit = base32_chars:find(c, 1, true)
        if not digit then
            error("无效的Base32字符: " .. c)
        end
        digit = digit - 1   -- 减1是因为Lua的索引从1开始,而我们需要从0开始        
        -- 累积位值,每个Base32字符代表5位二进制数据
        value = bit32.lshift(value, 5) + digit
        bits = bits + 5     -- 增加5位
        
        -- 当累积了至少8位时,我们可以生成一个字节
        if bits >= 8 then
            bits = bits - 8 -- 用掉8位
            -- 提取最左边的8位(相当于取整钱)
            -- bit32.rshift(value, bits)将多余的位右移去掉,bit32.band(..., 0xFF)只保留低8位
            result = result .. string.char(bit32.band(bit32.rshift(value, bits), 0xFF))
        end
    end
    -- 返回解码后的二进制数据
    return result
end

-- SHA-1的常量值,SHA-1算法中的四个固定常量
local K = {0x5A827999, 0x6ED9EBA1, 0x8F1BBCDC, 0xCA62C1D6}

-- 左旋转函数,将一个32位数字向左旋转指定的位数
local function rotl32(n, b)
    -- 先将n左移b位,然后将移出的部分放到右边
    -- 例如:rotl32(0001, 1) = 0010 (左移) + 0000 (右移) = 0010
    return bit32.band(bit32.lshift(n, b), 0xFFFFFFFF) + bit32.rshift(n, 32 - b)
end

-- 将消息填充到512比特的倍数,SHA-1要求输入是512位(64字节)的倍数,就像把信折成特定大小才能放进信封
local function pad_message(message)
    local length_bits = #message * 8  -- 计算消息总位数
    message = message .. string.char(0x80) -- 添加0x80字节(10000000)
    -- 填充0x00,直到长度为512比特的倍数减64比特,最后64位要留给原始消息的长度
    while (#message * 8) % 512 ~= 448 do
        message = message .. string.char(0x00)
    end   
    -- 添加原始长度(64位,大端序),就像在信封上标注信的原始大小
    local length_bytes = {}
    for i = 7, 0, -1 do
        -- 从高位到低位,依次提取长度的每8位
        local byte_val = bit32.band(math.floor(length_bits / (256 ^ i)), 0xFF)
        table.insert(length_bytes, string.char(byte_val))
    end    
    message = message .. table.concat(length_bytes)
    return message
end

-- 将512比特的消息块转换为16个32位字,把大消息分成小块处理
local function words_from_block(block)
    local words = {}
    for j = 1, 16 do
        -- 每4个字节组成一个32位字
        local byte1, byte2, byte3, byte4 = block:byte((j - 1) * 4 + 1, (j - 1) * 4 + 4)
        byte2 = byte2 or 0  -- 如果不足4个字节,用0填充
        byte3 = byte3 or 0
        byte4 = byte4 or 0
        
        -- 将4个字节合并成一个32位整数,就像把4个小积木拼成一个大积木
        words[j] = bit32.bor(
            bit32.lshift(byte1, 24), -- 第一个字节移到最高位
            bit32.lshift(byte2, 16), -- 第二个字节移到次高位
            bit32.lshift(byte3, 8),  -- 第三个字节移到次低位
            byte4           -- 第四个字节在最低位
        )
    end
    return words
end

-- SHA-1主函数,对消息进行哈希处理,生成160位的哈希值,就像给消息做一个"指纹"
function sha1(message)
    -- 初始哈希值,SHA-1算法定义的5个初始常量,就像拼图的5个角块
    local H0 = 0x67452301
    local H1 = 0xEFCDAB89
    local H2 = 0x98BADCFE
    local H3 = 0x10325476
    local H4 = 0xC3D2E1F0

    -- 填充消息,确保消息长度是512位的倍数
    message = pad_message(message)

    -- 处理每个512比特块,把消息分成多个小块依次处理
    for i = 1, #message, 64 do
        local block = message:sub(i, i + 63)
        local words = words_from_block(block)

        -- 扩展到80个字,从16个字扩展到80个字,增加数据的复杂性
        for t = 17, 80 do
            local word = bit32.bxor(words[t - 3], words[t - 8], words[t - 14], words[t - 16])
            words[t] = rotl32(word, 1)
        end

        -- 初始化工作变量,5个工作变量,用于临时存储计算结果
        local A, B, C, D, E = H0, H1, H2, H3, H4

        -- 主循环 - 80轮处理,对每个扩展后的字进行复杂的位运算
        for t = 1, 80 do
            local temp, f
            local k_index = math.floor((t - 1) / 20) + 1
            
            -- 根据当前轮数选择不同的逻辑函数,就像根据不同的步骤做不同的动作
            if t <= 20 then
                -- 逻辑函数1:(B AND C) OR (NOT B AND D)
                f = bit32.bor(bit32.band(B, C), bit32.band(bit32.bnot(B), D))
            elseif t <= 40 then
                -- 逻辑函数2:B XOR C XOR D
                f = bit32.bxor(B, C, D)
            elseif t <= 60 then
                -- 逻辑函数3:(B AND C) OR (B AND D) OR (C AND D)
                f = bit32.bor(bit32.band(B, C), bit32.band(B, D), bit32.band(C, D))
            else
                -- 逻辑函数4:B XOR C XOR D
                f = bit32.bxor(B, C, D)
            end
            
            -- 计算临时值,结合左旋转、逻辑函数、常量和当前字进行计算
            temp = bit32.band(rotl32(A, 5) + f + E + words[t] + K[k_index], 0xFFFFFFFF)
            
            -- 循环移位工作变量,就像5个人站成一圈,依次传递任务
            E = D
            D = C
            C = rotl32(B, 30)
            B = A
            A = temp
        end

        -- 更新哈希值,将这一轮的计算结果累加到最终哈希值上
        H0 = bit32.band(H0 + A, 0xFFFFFFFF)
        H1 = bit32.band(H1 + B, 0xFFFFFFFF)
        H2 = bit32.band(H2 + C, 0xFFFFFFFF)
        H3 = bit32.band(H3 + D, 0xFFFFFFFF)
        H4 = bit32.band(H4 + E, 0xFFFFFFFF)
    end

    -- 返回最终的哈希值(二进制格式),将5个32位整数转换为20字节的二进制数据
    local hash_bytes = {}
    for i, h in ipairs({H0, H1, H2, H3, H4}) do
        for j = 3, 0, -1 do
            -- 从高位到低位,提取每个整数的4个字节
            local byte_val = bit32.band(bit32.rshift(h, j * 8), 0xFF)
            table.insert(hash_bytes, string.char(byte_val))
        end
    end   
    return table.concat(hash_bytes)
end

-- 定义HMAC-SHA1函数,基于SHA-1的密钥哈希消息认证码,就像给消息加了一把带密码的锁
local function hmac_sha1(key, data)
    -- 如果密钥长度超过64字节,先进行SHA-1哈希,确保密钥长度不超过块大小
    if #key > 64 then
        key = sha1(key)
    end
    
    -- 将密钥填充到64字节,就像把钥匙磨成特定形状
    key = key .. string.rep("\0", 64 - #key)
    
    -- 内部和外部密钥,创建两个不同的密钥变体
    local inner_key = {}
    local outer_key = {}
    
    -- 对密钥的每个字节进行处理,内部密钥:与0x36异或外部密钥:与0x5C异或
    for i = 1, 64 do
        local byte_val = key:byte(i)
        table.insert(inner_key, string.char(bit32.bxor(byte_val, 0x36)))
        table.insert(outer_key, string.char(bit32.bxor(byte_val, 0x5C)))
    end    
    inner_key = table.concat(inner_key)
    outer_key = table.concat(outer_key)
    
    -- 计算内部哈希,先用内部密钥和数据计算哈希
    local inner_hash = sha1(inner_key .. data)
    
    -- 计算外部哈希,再用外部密钥和内部哈希结果计算最终哈希
    return sha1(outer_key .. inner_hash)
end

-- 定义TOTP函数,基于时间的一次性密码算法,就像一个随时间变化的密码锁
local function totp(secret, time_step, digits)
    -- 获取当前时间戳,记录从1970年1月1日到现在的秒数
    local current_time = os.time()
    -- 计算时间步长,将时间分成固定长度的"块"
    local counter = math.floor(current_time / time_step)
    
    -- 将计数器转换为8字节的大端序字节数组,就像把数字翻译成计算机能理解的语言
    local counter_bytes = {}
    for i = 7, 0, -1 do
        local byte_val = bit32.band(math.floor(counter / (256 ^ i)), 0xFF)
        counter_bytes[#counter_bytes + 1] = string.char(byte_val)
    end
    counter_bytes = table.concat(counter_bytes)
    
    -- 计算HMAC-SHA1,使用密钥和计数器生成哈希值
    local hash = hmac_sha1(secret, counter_bytes)
    
    -- 动态偏移量(从哈希结果的最后一个字节的低4位获取),就像从密码锁中随机选择一个位置开始
    local offset = bit32.band(hash:byte(20), 0x0F)
    
    -- 提取4字节数据(从offset位置开始),从哈希结果中提取一部分作为基础
    local binary_code = 0
    for i = 1, 4 do
        local byte_val = hash:byte(offset + i)
        binary_code = bit32.bor(binary_code, bit32.lshift(byte_val, (4 - i) * 8))
    end
    
    -- 移除最高位并取模,确保结果是一个正整数,并限制位数
    binary_code = bit32.band(binary_code, 0x7FFFFFFF)
    local otp = binary_code % (10 ^ digits)
    
    -- 格式化为固定长度的字符串,就像把密码格式化成固定位数的数字
    return string.format("%0" .. digits .. "d", otp)
end

-- 时间同步检查(可选),确保客户端和服务器的时间一致,就像校准两个钟表
local function check_time_synchronization(server_time_url)
    -- 实际应用中,应从服务器获取时间戳进行比对,这里仅为示例,实际实现需要HTTP请求
    print("警告:时间同步检查功能未实现,建议与服务器时间进行比对")
    return true
end

-- 主函数,程序的入口点
local function main()
    -- 阿里云MFA密钥(Base32编码),这是你在阿里云设置MFA时得到的密钥
    local base32_secret = "zABCDEghijklmnopqrstuvwxyzABCDEF"
    
    -- 解码Base32密钥,将密钥从"密码本语言"翻译成计算机能理解的二进制
    local secret = base32_decode(base32_secret)
    
    -- TOTP参数,时间步长:每30秒生成一个新密码,验证码位数:生成6位数字密码
    local time_step = 30  -- 步长为30秒
    local digits = 6      -- 验证码位数
    
    -- 检查时间同步,确保你的电脑时间准确
    if check_time_synchronization() then
        -- 生成当前验证码
        -- 基于当前时间和密钥计算出一次性密码
        local otp = totp(secret, time_step, digits)
        print("当前时间戳:", os.time())
        print("TOTP验证码:", otp)
    else
        print("错误:本地时间与服务器时间不同步,请调整系统时间")
    end
end

-- 执行主函数,启动整个程序
main()


# /usr/local/luajit/bin/luajit totp_generator.lua   #手工执行一下(用手机阿里云或者谷歌验证器绑一下zABCDEghijklmnopqrstuvwxyzABCDEF核对一下) 

警告:时间同步检查功能未实现,建议与服务器时间进行比对
当前时间戳:	1749723585
TOTP验证码:	097080

二、用标准目录形式为用户产生密钥串

#上面的例子我们已经可以用单一的秘钥生成了6位验证码,现实场景呢,一般也不会让用户自行生成秘钥,一般都是系统自动分配的唯一秘钥串,字符串长度呢就是8的倍数。你像阿里面向用户比较多所以它比较长64位,我们就给内部员工使用不需要那么长16位就可以了。

2.1 准备mysql环境

#这里我们假设用户注册后的用户信息是存储到数据库中的(很多时候新员工入职的时候所有信息会存在一个指定的库表中,然后其他功能再依此延伸,你比如有个字段会自动创建一个16位的秘钥字符串,而我们就需要用户字母名称和秘钥串这两个字段,不管是定时任务式的还是触发式的很多时候我们会把这个信息存储到文本或者redis缓存中便于加载)。

MariaDB [(none)]> create database auth_db;
MariaDB [(none)]> grant all privileges on auth_db.* to 'auth'@'127.0.0.1' identified by 'auth123';
MariaDB [(none)]> flush privileges;

2.2 准备相关lua文件

# mkdir /usr/local/nginx/conf/auth

# vim /usr/local/nginx/conf/auth/config.lua   #创建配置文件

-- 系统配置信息,创建了一个局部变量名为config的表,表Lua中最常用的数据结构,类似于其他语言中的字典、哈希表或对象
local config = {
    -- 数据库配置
    db = {
        host = "127.0.0.1", -- MySQL服务器地址
        port = 3306,  -- 端口
        database = "auth_db", -- 数据库名
        user = "auth",  -- 用户名
        password = "auth123", -- 密码
        max_idle_timeout = 10000, -- 连接池空闲超时(毫秒)
        pool_size = 100    -- 连接池大小
    },
    
    -- 二次验证配置
    otp = {
        code_length = 6, -- 验证码长度
        time_step = 30,  -- 时间步长(秒)
        window_size = 1,  -- 验证窗口(前后允许的时间步数)
        secret_length = 16 -- 密钥长度(字节)
    },
    
    -- 登录限制配置
    login_attempts = {
        max_attempts = 10000, -- 最大尝试次数
        lockout_time = 60 -- 锁定时间(秒)
    },
    
    -- 会话配置
    session = {
        expires_in = 86400, -- 会话有效期(秒)
        cookie_name = "auth_token" -- Cookie名称
    },
    
    -- IP白名单配置
    whitelist_file = "/usr/local/nginx/conf/auth/ip_whitelist.txt",  -- 白名单文件路径
    
    -- 路径配置
    login_path = "/verify",       -- 登录页面路径
    register_path = "/register"   -- 注册页面路径
}
-- return语句用于将模块的内容导出,使其他模块可以通过 require函数加载和使用这些内容,将这个配置表导出为模块的公共接口
return config

# vim /usr/local/nginx/conf/auth/db.lua  #创建数据库操作文件

-- 数据库操作模块
local mysql = require("resty.mysql")
local config = require("config").db

local _M = {}

-- 创建数据库连接
function _M.connect()
    local db, err = mysql:new()
    if not db then
        return nil, "创建MySQL实例失败: " .. err
    end
    
    -- 设置超时
    db:set_timeout(1000)  -- 1秒超时
    
    -- 连接数据库
    local ok, err, errno, sqlstate = db:connect({
        host = config.host,
        port = config.port,
        database = config.database,
        user = config.user,
        password = config.password,
        max_packet_size = config.max_packet_size
    })
    
    if not ok then
        return nil, string.format("连接数据库失败: %s, 错误码: %d, SQL状态: %s", err, errno, sqlstate)
    end   
    return db
end

-- 安全执行SQL查询
-- _M是模块的命名空间,代表当前模块,execute_query是函数名,sql是传入的SQL查询语句,字符串类型
function _M.execute_query(sql)
    -- 调用模块内部的connect函数创建数据库连接,如果连接失败(db为nil),返回nil和错误信息
    local db, err = _M.connect()
    if not db then
        return nil, err
    end
    
    -- 执行查询,db:query(sql)是resty.mysql库提供的执行SQL查询的方法,其中errno:MySQL错误码(失败时),sqlstate:SQL状态码(失败时)
    local res, err, errno, sqlstate = db:query(sql)
    
    -- 将连接放回连接池,config.max_idle_timeout:连接在池中保持空闲的最长时间(ms),config.pool_size:连接池的最大大小
    db:set_keepalive(config.max_idle_timeout, config.pool_size)
    
    -- 如果查询失败(res为nil),返回nil和格式化的错误信息
    if not res then
        return nil, string.format("执行SQL失败: %s, 错误码: %d, SQL状态: %s", err, errno, sqlstate)
    end   
    -- 如果查询成功,返回查询结果res
    return res
end

-- 安全转义SQL字符串(防止SQL注入),将任意类型的值转换为适合在SQL语句中使用的安全字符串表示形式
function _M.escape_literal(str)
    -- 在SQL中,nil对应NULL关键字,返回字符串"NULL"(不带引号),使其能直接在SQL中作为空值使用
    if str == nil then
        return "NULL"
    end
    
    -- 处理布尔值,SQL中通常使用1和0表示布尔值true和false,将布尔值转换为对应的数字字符串
    if type(str) == "boolean" then
        return str and "1" or "0"
    end
    
    -- 处理数字,直接将数字转换为字符串,保持其数值表示
    if type(str) == "number" then
        return tostring(str)
    end
    
    -- 处理字符串确保输入值被视为字符串,用单引号将转义后的字符串包围,形成合法的SQL字符串字面量
    str = tostring(str)
    -- 转义单引号和反斜杠
    str = string.gsub(str, "\\", "\\\\")
    str = string.gsub(str, "'", "''")
    return "'" .. str .. "'"
end

-- 初始化数据库表
function _M.init_tables()
    -- 创建用户表
    local create_users_sql = [[
        CREATE TABLE IF NOT EXISTS users (
            id INT AUTO_INCREMENT PRIMARY KEY,
            username VARCHAR(50) UNIQUE NOT NULL,
            otp_secret VARCHAR(32) NOT NULL,
            login_attempts INT DEFAULT 0,
            last_attempt_time DATETIME,
            created_at DATETIME DEFAULT CURRENT_TIMESTAMP
        )
    ]]
    
    local res, err = _M.execute_query(create_users_sql)
    if not res then
        return false, "创建用户表失败: " .. err
    end
    
    -- 创建会话表
    local create_sessions_sql = [[
        CREATE TABLE IF NOT EXISTS sessions (
            session_id VARCHAR(64) PRIMARY KEY,
            user_id INT NOT NULL,
            expires_at DATETIME NOT NULL,
            created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
            FOREIGN KEY (user_id) REFERENCES users(id)
        )
    ]]
    
    local res, err = _M.execute_query(create_sessions_sql)
    if not res then
        return false, "创建会话表失败: " .. err
    end
    
    return true
end
return _M

# vim /usr/local/nginx/conf/auth/utils.lua  #工具函数模块,提供通用功能的文件

local config = require("config")

local _M = {}
-- 定义模块的公共函数generate_random_string,接受一个参数length表示要生成的随机字符串长度 
function _M.generate_random_string(length)
    -- 初始化空字符串result用于存储生成的随机字符串
    local result = ""
    -- 定义字符集chars,包含所有可能被使用的字符(大小写字母和数字)。
    local chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"

    -- 通过执行系统命令echo $$获取当前进程 ID。io.popen()打开一个管道执行命令,read()读取输出,tonumber()转换为数字。如果获取失败则默认为0
    local pid = tonumber(io.popen("echo $$"):read()) or 0

    -- 使用时间和进程ID组合作为种子,结合进程ID可以避免同一秒内生成相同的随机序列
    math.randomseed(os.time() + pid)

    -- 预调用math.random()三次,丢弃结果。这是Lua中提高随机数质量的常见做法,因为首次调用往往不够随机。
    for _ = 1, 3 do
        math.random()
    end
    
    -- ngx.log(ngx.INFO, "进入生成验证码阶段")
    -- 上面是打印一下内容,在排错阶段很有用,在开发阶段多加一些输出,可以知道代码在哪一步出错了,需要就打开注释
    -- 循环length次,每次生成一个1到字符集长度之间的随机索引
    for i = 1, length do
        -- 只调用一次math.random,获取单个字符
        local random_index = math.random(1, #chars)
        -- 使用sub()方法从字符集中提取对应位置的单个字符,并追加到result中
        result = result .. chars:sub(random_index, random_index)
    end
    -- 返回生成的随机字符串
    return result
end

-- 解析HTTP请求参数
function _M.parse_args()
    local args = {}
    
    -- 解析GET参数
    if ngx.var.request_method == "GET" then
        local query = ngx.var.query_string
        if query and query ~= "" then
            for k, v in string.gmatch(query, "([^&=]+)=([^&=]*)") do
                args[k] = ngx.unescape_uri(v)
            end
        end
    -- 解析POST参数
    elseif ngx.var.request_method == "POST" then
        ngx.req.read_body()
        local post_data = ngx.req.get_post_data()
        if post_data and post_data ~= "" then
            for k, v in string.gmatch(post_data, "([^&=]+)=([^&=]*)") do
                args[k] = ngx.unescape_uri(v)
            end
        end
    end
    
    return args
end

-- 检查字符串是否为空
function _M.is_empty(str)
    return str == nil or str == ""
end
-- 将模块表_M作为结果返回,使外部代码可以访问其公共函数
return _M

# vim /usr/local/nginx/conf/auth/otp.lua #注册用户产生随机密钥的文件

-- 导入三个外部模块
local db = require("db")
local config = require("config")
local utils = require("utils")

local _M = {}

-- 定义generate_secret函数,用于生成 OTP 密钥
function _M.generate_secret()
    -- 调用utils模块的generate_random_string函数,生成长度为config.otp.secret_length的随机字符串
    return utils.generate_random_string(config.otp.secret_length)
end

-- 注册新用户(生成OTP密钥)
function _M.register_user(username)
    -- 首先调用generate_secret生成OTP密钥
    local otp_secret = _M.generate_secret()

    -- 验证生成的密钥长度是否超过 32 个字符(假设数据库字段类型为 VARCHAR (32))
    if #otp_secret > 32 then
        --如果超过则截断为前 32 个字符,防止数据库插入错误
        otp_secret = string.sub(otp_secret, 1, 32) 
    end
    -- ngx.log(ngx.INFO,"otp_secret:",otp_secret)
    
    -- 构建SQL插入语句,将用户名和OTP密钥存入users表
    local sql = string.format([[
        INSERT INTO users (username, otp_secret)
        VALUES (%s, %s)
    ]],
    -- 使用db.escape_literal对输入进行转义,防止SQL注入攻击
    db.escape_literal(username),
    db.escape_literal(otp_secret)
    )
    -- 执行SQL插入语句
    local res, err = db.execute_query(sql)
    if not res then
        return false, nil, "注册失败: " .. err
    end
    
    return true, otp_secret, "注册成功"
end

return _M

#vim /usr/local/nginx/conf/auth/main.lua  #主程序入口文件

-- ngx是OpenResty提供的全局对象,用于访问Nginx API
local ngx = ngx
local otp = require("otp")
local db = require("db")
local utils = require("utils")
local config = require("config")

-- 调用db.init_tables()初始化数据库表(如果不存在),失败时返回500错误并终止请求处理
local ok, err = db.init_tables()
if not ok then
    -- ngx.log(ngx.ERR, "数据库连接失败: ", err)
    ngx.status = 500
    ngx.say("数据库初始化失败: ", err)
    ngx.exit(ngx.HTTP_INTERNAL_SERVER_ERROR)
end

-- 显示登录弹窗
local function show_login_popup(error_msg, original_url)
    error_msg = error_msg or ""
    original_url = original_url or ngx.var.request_uri
    
    -- 创建HTML模板
    local html = [[
    <!DOCTYPE html>
    <html lang="zh-CN">
    <head>
        <meta charset="UTF-8">
        <title>二次验证</title>
        <style>
            body { font-family: Arial, sans-serif; }
            .overlay {
                position: fixed;
                top: 0;
                left: 0;
                width: 100%;
                height: 100%;
                background-color: rgba(0,0,0,0.5);
                display: flex;
                justify-content: center;
                align-items: center;
                z-index: 9999;
            }
            .popup {
                background-color: white;
                padding: 20px;
                border-radius: 5px;
                box-shadow: 0 0 10px rgba(0,0,0,0.3);
                max-width: 400px;
                width: 100%;
                animation: fadeIn 0.3s ease-in-out;
            }
            @keyframes fadeIn {
                from { opacity: 0; transform: scale(0.9); }
                to { opacity: 1; transform: scale(1); }
            }
            .form-group { margin-bottom: 15px; }
            .form-group label { display: block; margin-bottom: 5px; font-weight: bold; }
            .form-group input { 
                width: 100%; 
                padding: 10px; 
                border: 1px solid #ddd; 
                border-radius: 4px; 
                box-sizing: border-box;
            }
            button { 
                background-color: #4CAF50; 
                color: white; 
                padding: 10px 15px; 
                border: none; 
                border-radius: 4px; 
                cursor: pointer; 
                width: 100%;
                font-size: 16px;
            }
            button:hover { background-color: #45a049; }
            .error { 
                color: red; 
                margin-bottom: 10px; 
                padding: 8px; 
                background-color: #ffebee; 
                border-radius: 4px;
            }
            .info {
                margin-top: 15px;
                font-size: 14px;
                color: #666;
            }
        </style>
    </head>
    <body>
        <div class="overlay">
            <div class="popup">
                <h2>二次验证</h2>
                {{error_message}}
                <p>您需要进行二次验证才能访问此资源</p>
                <form method="post" action="{{action_url}}">
                    <input type="hidden" name="original_url" value="{{original_url}}" />
                    <div class="form-group">
                        <label for="username">用户名:</label>
                        <input type="text" id="username" name="username" required placeholder="输入您的用户名">
                    </div>
                    <div class="form-group">
                        <label for="code">二次验证码:</label>
                        <input type="text" id="code" name="code" required placeholder="输入6位数字验证码">
                    </div>
                    <button type="submit">验证并继续</button>
                </form>
                <div class="info">
                    <p>提示: 验证码每30秒更新一次</p>
                    <p>测试用注册: <a href="/register?username=test">/register?username=test</a></p>
                </div>
            </div>
        </div>
    </body>
    </html>
    ]]
    
    -- 安全替换函数
    local function safe_replace(html, placeholder, value)
        -- 确保替换值中的特殊字符被正确处理
        value = string.gsub(value, "%%", "%%%%")
        return string.gsub(html, placeholder, value)
    end
    
    -- 使用安全替换函数
    html = safe_replace(html, "{{error_message}}", 
    error_msg ~= "" and string.format('<div class="error">%s</div>', error_msg) or "")
    html = safe_replace(html, "{{original_url}}", ngx.escape_uri(original_url))
    html = safe_replace(html, "{{action_url}}", ngx.var.request_uri)
    
    ngx.header.content_type = "text/html; charset=utf-8"
    ngx.say(html)
    ngx.exit(ngx.HTTP_OK)
end

-- 处理注册请求(测试用)
if ngx.var.uri == config.register_path and ngx.req.get_method() == "GET" then
    local args = ngx.req.get_uri_args()
    local username = args.username
    if utils.is_empty(username) then
        ngx.header.content_type = "text/plain; charset=utf-8"
        ngx.status = 400
        ngx.say("请提供用户名")
        ngx.exit(ngx.HTTP_BAD_REQUEST)
    end
    
    -- 注册新用户,调用otp.register_user生成OTP密钥并存储到数据库
    local ok, otp_secret, message = otp.register_user(username)
    
    -- 成功时返回包含用户和OTP 密钥的HTML页面
    if ok then
        -- 设置Content-Type为text/html,确保浏览器正确解析HTML
        ngx.header.content_type = "text/html; charset=utf-8"
        ngx.say([[
        <!DOCTYPE html>
        <html lang="zh-CN">
        <head>
            <meta charset="UTF-8">
            <title>注册成功</title>
            <style>
                body { font-family: Arial, sans-serif; }
                .container { max-width: 600px; margin: 0 auto; padding: 20px; }
                .success { color: green; font-weight: bold; }
                .back-link { display: inline-block; margin-top: 20px; }
            </style>
        </head>
        <body>
            <div class="container">
                <h2>注册成功</h2>
                <p class="success">用户名: ]] .. username .. [[</p>
                <p class="success">OTP密钥: ]] .. otp_secret .. [[</p>
                <p>请使用此密钥在Google Authenticator中生成验证码</p>
                <a href="/verify" class="back-link">返回登录页面</a>
            </div>
        </body>
        </html>
        ]])
    else
        ngx.header.content_type = "text/plain; charset=utf-8"
        ngx.status = 500
        ngx.say("注册失败: ", message)
        ngx.exit(ngx.HTTP_INTERNAL_SERVER_ERROR)
    end
    -- 处理完注册请求后,直接退出,不再执行后续代码
    ngx.exit(ngx.HTTP_OK)
end

-- 显示登录弹窗
show_login_popup(nil, ngx.var.request_uri)

# vim /usr/local/nginx/conf/nginx.conf  #nginx主配置文件配置

#gzip  on;
    lua_package_path "/usr/local/lua_core/lib/lua/?.lua;/usr/local/nginx/conf/auth/?.lua;;";
......
    include /usr/local/nginx/conf/conf.d/*.conf;

# vim /usr/local/nginx/conf/conf.d/lua.test.com.conf  #配置一个测试域名,然后本地电脑配置hosts

server {
    listen 80;
    server_name lua.test.com;
    root /opt/web;
    charset utf-8;
    error_log /usr/local/nginx/logs/error.log debug;

    # 主应用入口
    location / {
        default_type 'text/plain';
        # 指定 Lua 处理程序
        content_by_lua_file /usr/local/nginx/conf/auth/main.lua;
    }
    location ~ ^/favicon.ico$ {
        return 404;
        root    /opt/www/favicon;
    }
}

# /usr/local/nginx/sbin/nginx -s reload

2.3 访问验证一下

先注册一个test用户:http://lua.test.com/register?username=test   #如果已经存在test用户了,web页面会提示:

注册失败: 注册失败: 执行SQL失败: Duplicate entry 'test' for key 'username', 错误码: 1062, SQL状态: 23000

再来注册一个不存在的用户:http://lua.test.com/register?username=test1  #下面是web页面显示效果

image.png

下面我们点击下返回登录页面进入验证页面(http://lua.test.com/verify),现在访问任何页面都会页面弹窗需要登录验证除了注册页面,如下图:

image.png

三、增加一个IP白名单的功能

#如果我们有一些固定的出口IP,使用者不希望每天都进行一次登录验证,或者当程序验证有问题频繁弹出验证框的时候,IP白名单功能就很有用了。

3.1 代码编写

# vim /usr/local/nginx/conf/auth/utils.lua  #增加IP处理逻辑

-- 日志函数
function _M.log(level, msg)
    -- 分级日志:支持 debug、info、warn、error 四种日志级别 
    local log_levels = {
        debug = ngx.DEBUG,
        info = ngx.INFO,
        warn = ngx.WARN,
        error = ngx.ERR
    }
    -- 默认级别:若传入的日志级别不合法,默认使用info级别
    local log_level = log_levels[level] or log_levels.info
    --日志前缀:所有日志都会自动添加 [auth] 前缀,方便在 Nginx 日志中快速定位和过滤
    ngx.log(log_level, "[auth] ", msg)
end

-- 读取文件内容
function _M.read_file(path)
        -- 使用io.open(path, "r") 以只读模式打开文件,若文件打开失败返回nil并记录错误
    local file = io.open(path, "r")
    if not file then
        _M.log("error", "Failed to open file: " .. path)
        return nil
    end
        -- 使用file:read("*all") 一次性读取整个文件内容,适合小文件。大文件可能导致内存溢出。
        -- 若处理大文件(如超过 100MB),建议改用循环读取(如 file:read(4096))
        -- 可以使用pcall包裹读取逻辑,防止文件读取过程中出错导致Lua脚本崩溃
    local content = file:read("*all")
    -- 无论文件读取是否成功,file:close() 都会执行(在当前实现中隐含正确关闭,更健壮的写法可使用 finally 确保关闭)
    file:close()
    return content
end

-- 检查IP是否在CIDR范围内
function _M.ip_in_cidr(ip, cidr)
    local ip_parts = {}
    -- 使用 ip:gmatch("%d+") 提取IP地址的四个数字部分,解析IP地址为四部分数字
    for part in ip:gmatch("%d+") do
        table.insert(ip_parts, tonumber(part))
    end

    if #ip_parts ~= 4 then
        return false
    end
        -- 解析CIDR格式 
    local cidr_ip, cidr_mask = cidr:match("^([^/]+)/(%d+)$")
    if not cidr_ip or not cidr_mask then
            -- 若不是CIDR格式,直接比较IP是否相等 
        return ip == cidr
    end
        -- 解析CIDR中的IP部分
    local cidr_parts = {}
    for part in cidr_ip:gmatch("%d+") do
        table.insert(cidr_parts, tonumber(part))
    end

    if #cidr_parts ~= 4 then
        return false
    end
        -- 验证掩码合法性
    cidr_mask = tonumber(cidr_mask)
    if cidr_mask < 0 or cidr_mask > 32 then
        return false
    end
        -- 将IP转换为32位整数
    local ip_num = (ip_parts[1] * 2^24) + (ip_parts[2] * 2^16) + (ip_parts[3] * 2^8) + ip_parts[4]
    local cidr_num = (cidr_parts[1] * 2^24) + (cidr_parts[2] * 2^16) + (cidr_parts[3] * 2^8) + cidr_parts[4]
    local mask = bit.lshift(0xFFFFFFFF, 32 - cidr_mask)
        -- 通过按位与运算判断IP是否在CIDR范围内
    return bit.band(ip_num, mask) == bit.band(cidr_num, mask)
end

#vim /usr/local/nginx/conf/auth/main.lua  #主程序入口文件,增加一些代码,代码分拆到文件中指定的位置哈

-- 获取客户端IP
local function get_client_ip()
    local ip = ngx.var.remote_addr
    
    -- 检查代理头
    if ngx.var.http_x_forwarded_for then
        ip = ngx.var.http_x_forwarded_for:match("^([^,]+)")
    end
    
    return ip
end

-- 检查IP是否在白名单中
local function check_ip_whitelist(ip)
    local whitelist_file = config.whitelist_file
    local content = utils.read_file(whitelist_file)
    
    if not content then
        ngx.log(ngx.WARN, "白名单文件不存在或无法读取: ", whitelist_file)
        return false
    end
    
    -- 检查IP是否在白名单中
    for line in content:gmatch("[^\r\n]+") do
        local whitelisted_ip = line:match("^%s*(.-)%s*$")
        if whitelisted_ip and whitelisted_ip ~= "" then
            if utils.ip_in_cidr(ip, whitelisted_ip) then
                return true
            end
        end
    end
    
    return false
end
-- 主逻辑开始
local client_ip = get_client_ip()
ngx.log(ngx.INFO, "客户端IP: ", client_ip)

-- 检查IP白名单
if check_ip_whitelist(client_ip) then
    ngx.log(ngx.INFO, "IP在白名单中,跳过验证: ", client_ip)
    -- return  -- 继续处理原始请求,return后面一定要跟ngx.exit(ngx.DECLINED),不然你访问什么都是200状态码但是没内容
    -- 声明当前阶段处理完成,让Nginx继续后续处理
    return ngx.exit(ngx.DECLINED)
end

#这里我们要着重了解一下单纯的rerurn、return ngx.exit(ngx.DECLINED)、return ngx.exec(ngx.var.uri)的区别:

return的作用:

终止当前 Lua 代码块的执行:当遇到 return 时,Lua 脚本会立即退出当前函数或代码块,但 不会通知 Nginx 请求处理状态。适用场景:

在 if-else 分支中提前结束逻辑,继续执行后续 Lua 代码,当需要返回值给调用者时(但在 Nginx 处理阶段中很少使用)。

当Nginx 处于 content_by_lua_file 阶段(负责生成响应内容)。由于没有显式生成内容或转发请求,Nginx 最终返回一个空响应(状态码 200,但内容为空)。

return ngx.exit(ngx.DECLINED) 的作用:

显式通知 Nginx 当前阶段处理完成:ngx.DECLINED 是一个特殊状态码,表示 “我不处理这个请求,交给后续模块/阶段继续处理”。适用场景:

在access 阶段(如 access_by_lua_file)中,用于跳过当前验证逻辑,让 Nginx 继续处理请求(如寻找静态文件、执行 proxy_pass等)当需要将请求流程控制权交还给 Nginx 时。

ngx.exec(ngx.var.uri) 的作用:

强制 Nginx 重新路由请求,从头开始处理当前 URI(包括执行 index 指令、寻找静态文件等)。

适用于 content_by_lua_file 场景,确保请求被正确转发到后续处理流程。

#echo "客户端IP" >>/usr/local/nginx/conf/auth/ip_whitelist.txt  #把你要加的IP添加到此白名单文件中,一个IP或者网段单独一行

# vim /opt/web/index.html   #再来一个最简单的测试页面

哈哈少年,恭喜你通过了考验!

3.2 简单测试

浏览器分别访问:http://lua.test.com/index.html和http://lua.test.com/dasdasdas(自己验证结果哈,就不截图了,会分别显示内容和404)

作者:忙碌的柴少 分类:Lua 浏览:149 评论:0
留言列表
发表评论
来宾的头像