![](/style/images/good.png)
![](/style/images/bad.png)
refine patch(more flexible) · findstr/silly@59941ab · GitHub
source link: https://github.com/findstr/silly/commit/59941ab0f318408c581f14fc8cef5ca55b0e63ec
Visit the source link to view the original page, which includes the images, the latest updates, and better formatting. If the link is broken, click the button below to view the archived snapshot.
refine patch(more flexible) · findstr/silly@59941ab · GitHub
@@ -0,0 +1,19 @@
local console = require "sys.console"
local run = require "patch.run"
-- Build the reply shared by both console commands: the current counters
-- from patch.run plus the global `step` (set by the hot-patched code).
-- `run.dump` is looked up on every call, so a hot-swapped dump is used.
local function describe()
	local a, b = run.dump()
	return string.format("a:%s,b:%s,step:%s", a, b, tostring(step))
end

-- Debug console on 127.0.0.1:2345 with two commands that each mutate
-- patch.run through its current functions and then report its state.
console {
	addr = "127.0.0.1:2345",
	cmd = {
		foo = function()
			run.foo()
			return describe()
		end,
		bar = function()
			run.bar()
			return describe()
		end,
	}
}
@@ -0,0 +1,20 @@
-- Presumably examples/patch/fix.lua: the replacement for patch.run that the
-- example patch script loads into a sandbox ENV and hot-swaps in.
-- NOTE: the upvalue names `a` and `b` must match the running module's,
-- because sys.patch pairs upvalues by name; renaming them breaks the patch.
local M = {}
-- Intentional global (no `local`): when the chunk is loaded with a sandbox
-- ENV this assignment lands in that table; the patch script then copies it
-- into the real globals, where the console commands read `step`.
step = 3
-- Once patched by the example script, these slots are replaced by the
-- running module's `a`/`b` upvalues, so the initial values here go unused.
local a, b = 1, 1
-- Sets the global `step` to 4 and advances both counters by it.
-- The running module's foo captures only `a`, so this `b` has no
-- same-named counterpart and is reported absent by the join; the example
-- patch script ties it by hand to dump's `b`.
function M.foo()
step = 4
a = a + step
b = b + step
end
-- Increments `b` by 1.
function M.bar()
b = b + 1
end
-- Returns the current values of `a` and `b`.
function M.dump()
return a, b
end
return M
@@ -0,0 +1,20 @@
local patch = require "sys.patch"
local run = require "patch.run"
-- Sandbox environment for the fix chunk: reads fall through to the real
-- globals, while any global the fix assigns (e.g. `step`) lands in ENV.
local ENV = setmetatable({}, {__index = _ENV})
-- assert() surfaces loadfile's own error (missing file, syntax error)
-- instead of an opaque "attempt to call a nil value".
local fix = assert(loadfile("examples/patch/fix.lua", "bt", ENV))()
assert(type(fix) == "table", "examples/patch/fix.lua must return a table")
local P = patch:create()
-- Upvalue trees of the running module and of the replacement module.
local up1 = P:collectupval(run)
local up2 = P:collectupval(fix)
-- Point the replacement functions' upvalues at the running module's ones;
-- upvalue paths with no same-named counterpart are returned in `absent`.
local absent = P:join(up2, up1)
print(absent[1], absent[2])
-- fix.foo's `b` has no counterpart in run.foo, so share it by hand with
-- fix.dump's `b`, which the join above tied to the running module's `b`.
debug.upvaluejoin(up2.foo.val, up2.foo.upvals.b.idx, up2.dump.val, up2.dump.upvals.b.idx)
print("hotfix ok")
-- Swap the patched functions into the live module table.
for k, v in pairs(fix) do
	run[k] = v
end
-- Publish the fix's globals: functions always replace, other values only
-- fill names that are currently unset (nil or false).
for k, v in pairs(ENV) do
	if type(v) == "function" or not _ENV[k] then
		_ENV[k] = v
	end
end
@@ -0,0 +1,17 @@
-- Presumably the running patch.run module: the implementation the console
-- calls and that the example hot-patches in place.
-- NOTE: keep the upvalue names `a` and `b` stable; sys.patch matches the
-- replacement module's upvalues to these by name.
local M = {}
-- Module-private counters, captured as upvalues by the functions below.
local a, b = 1, 1
-- Increments `a` by 1 (this version never touches `b`).
function M.foo()
a = a + 1
end
-- Increments `b` by 1.
function M.bar()
b = b + 1
end
-- Returns the current values of `a` and `b`.
function M.dump()
return a, b
end
return M
@@ -14,7 +14,7 @@ cluster() {
}
module() {
./silly --lualib_path="lualib/?.lua" --lualib_cpath="luaclib/?.so" --bootstrap="examples/$1.lua"
./silly --lualib_path="lualib/?.lua;examples/?.lua" --lualib_cpath="luaclib/?.so" --bootstrap="examples/$1.lua"
}
all() {
@@ -28,7 +28,7 @@ local desc = {
"CPUINFO: Show system time and user time statistics. [CPUINFO]",
"SOCKET: Show socket detail information. [SOCKET]",
"TASK: Show all task status and traceback. [TASK]",
"PATCH: Hot patch the code. [PATCH <modulename> <filename>]",
"INJECT: INJECT code. [INJECT <path>]",
"DEBUG: Enter Debug mode. [DEBUG]",
}
@@ -37,17 +37,6 @@ local console = {}
-- Metatable for sandbox environments: unknown names fall back to globals.
local envmt = {__index = _ENV}
-- Hot-patch the already-required `module` with the table of functions
-- returned by the chunk in `filename`. The chunk runs in a fresh sandbox
-- environment (see envmt); both the running module and the chunk's result
-- must be tables. Errors are raised via assert.
local function _patch(module, filename)
	local ENV = setmetatable({}, envmt)
	local runm = require(module)
	-- Load from the `filename` parameter. (`fixfile` is not defined anywhere
	-- visible; a nil filename would make loadfile read from stdin.)
	local newm = assert(loadfile(filename, "bt", ENV))()
	assert(runm and type(runm) == "table")
	assert(newm and type(newm) == "table")
	patch(ENV, runm, newm)
end
-- Console HELP command: returns the `desc` table of command descriptions.
function console.help()
return desc
end
@@ -172,18 +161,20 @@ function console.info()
return concat(tbl, "\r\n")
end
function console.patch(_, module, filename)
if not module then
return "ERR lost the module file name"
elseif not filename == 0 then
return "ERR lost the filename"
function console.inject(_, filepath)
if not filepath == 0 then
return "ERR lost the filepath"
end
local ENV = setmetatable({}, envmt)
local ok, err = pcall(loadfile, filepath, bt, ENV)
if ok then
ok, err = pcall(err)
end
local ok, err = pcall(_patch, module, filename)
local fmt = "Patch module:%s function:%s by:%s %s"
local fmt = "Inject file:%s %s"
if ok then
return format(fmt, module, fix, "Success")
return format(fmt, filepath, "Success")
else
return format(fmt, module, fix, err)
return format(fmt, filepath, err)
end
end
@@ -228,6 +219,7 @@ return function (config)
core.log("console come in:", addr)
local param = {}
local dat = {}
socket.write(fd, "Hello\n")
socket.write(fd, prompt)
while true do
local l = socket.readline(fd)
@@ -1,110 +1,137 @@
local type = type
local pairs = pairs
local assert = assert
local dgetinfo = debug.getinfo
local dgetupval = debug.getupvalue
local dupvaljoin = debug.upvaluejoin
local concat = table.concat
local setmetatable = setmetatable
local getinfo = debug.getinfo
local getupval = debug.getupvalue
local setupvalue = debug.setupvalue
local upvalid = debug.upvalueid
local upvaljoin = debug.upvaluejoin
local function collectup(func, tbl, unique)
local info = dgetinfo(func, "u")
local M = {}
local mt = {__index = M}
local function collectfn(fn, upvals, unique)
local info = getinfo(fn, "u")
for i = 1, info.nups do
local up
local name, val = dgetupval(func, i)
local name, val = getupval(fn, i)
local utype = type(val)
if utype == "function" then
local upval = unique[val]
if not upval then
upval = {}
unique[val] = upval
collectup(val, upval, unique)
collectfn(val, upval, unique)
end
up = {
idx = i,
utype = utype,
val = val,
up = upval
upid = upvalid(fn, i),
upvals = upval
}
else
up = {
idx = i,
utype = utype,
val = val,
upid = upvalid(fn, i),
}
end
tbl[name] = up
upvals[name] = up
end
end
--- Create a new patch context.
-- Defined with `.`; callers use `patch:create()`, so the receiver passed
-- that way is simply ignored.
-- Fields:
--   collected - cache of upvalue trees per function value, shared by every
--               collectupval call on this context
--   valjoined - set of function values already visited by join
--   fnjoined  - NOTE(review): not read by any code visible here
local function_fields_doc = nil -- luacheck: ignore (placeholder removed below)
function M.create()
	local ctx = {
		collected = {},
		valjoined = {},
		fnjoined = {},
	}
	return setmetatable(ctx, mt)
end
local function join_val(joined, nfv, nup, rfv, rup)
assert(type(nfv) == "function")
assert(type(rfv) == "function")
for name, nu in pairs(nup) do
local ru = rup[name]
if ru then
local nidx = nu.idx
local ridx = ru.idx
assert(ru.utype == nu.utype or not nu.val or not ru.val)
if nu.utype == "function" then
local v = nu.val
if not joined[v] then
joined[v] = true
join_val(joined, v, nu.up, ru.val, ru.up)
end
else
dupvaljoin(nfv, nidx, rfv, ridx)
end
--- Collect the upvalue tree(s) for a function or a module table.
-- For a table, returns a map name -> {val = fn, upvals = tree} for every
-- field; for anything else, returns the function's upvalue tree directly.
-- Trees are keyed by upvalue name and built by collectfn, which shares
-- nested subtrees through this context's `collected` cache.
function M:collectupval(f_or_t)
	local cache = self.collected
	if type(f_or_t) ~= "table" then
		local tree = {}
		collectfn(f_or_t, tree, cache)
		return tree
	end
	local result = {}
	for name, fn in pairs(f_or_t) do
		local nested = {}
		collectfn(fn, nested, cache)
		result[name] = {val = fn, upvals = nested}
	end
	return result
end
local function join_fn(joined, rfv, rup, nfv, nup)
assert(type(nfv) == "function")
assert(type(rfv) == "function")
for name, ru in pairs(rup) do
local nu = nup[name]
if nu then
local nidx = nu.idx
local ridx = ru.idx
if nu.utype == "function" then
assert(ru.utype == nu.utype)
local v = nu.val
local function joinval(f1, up1, f2, up2, joined, path, absent)
local n = 0
local depth = #path + 1
for name, uv1 in pairs(up1) do
path[depth] = name
local uv2 = up2[name]
if uv2 then
local idx1 = uv1.idx
local idx2 = uv2.idx
assert(uv1.utype == uv2.utype or not uv1.val or not uv2.val)
if uv1.utype == "function" then
assert(uv2.utype == "function")
local v = uv1.val
if not joined[v] then
joined[v] = true
join_fn(joined, ru.val, ru.up, nu.val, nu.up)
n = n + joinval(v, uv1.upvals,
uv2.val, uv2.upvals,
joined, path, absent)
end
dupvaljoin(rfv, ridx, nfv, nidx)
else
upvaljoin(f1, idx1, f2, idx2)
end
elseif name == "_ENV" then
setupvalue(f1, uv1.idx, _ENV)
absent[#absent + 1] = concat(path, ".")
else
absent[#absent + 1] = concat(path, ".")
end
path[depth] = nil
end
return n
end
return function(newenv, runm, newm)
local unique = {}
local val_joined = {}
local fn_joined = {}
--join the runm to newm
for fn, nfv in pairs(newm) do
local rfv = runm[fn]
if rfv then
local ru = {}
local nu = {}
collectup(rfv, ru, unique)
collectup(nfv, nu, unique)
join_val(val_joined, nfv, nu, rfv, ru)
join_fn(fn_joined, rfv, ru, nfv, nu)
end
end
--replace runm to newm
for fn, nfv in pairs(newm) do
runm[fn] = nfv
end
--fix _ENV
for k, v in pairs(newenv) do
if not _ENV[k] then
_ENV[k] = v
function M:join(f1, up1, f2, up2)
local n
local path = {"$"}
local absent = {}
local joined = self.valjoined
if type(f1) == "table" then
local up1, up2 = f1, up1
n = 0
for name, uv1 in pairs(up1) do
local uv2 = up2[name]
if uv2 then
path[2] = name
n = n + joinval(uv1.val, uv1.upvals,
uv2.val, uv2.upvals,
joined, path, absent)
path[2] = nil
end
end
else
assert(type(f1) == "function")
assert(type(f2) == "function")
path[2] = "#"
n = joinval(f1, up1, f2, up2, joined, path, absent)
path[2] = nil
end
return
return absent
end
return M
@@ -1,8 +1,23 @@
local core = require "sys.core"
local patch = require "sys.patch"
local testaux = require "testaux"
--- Hot-patch running module M1 with replacement M2 through patch context P.
-- Joins M2's upvalues onto M1's and asserts that the first unmatched
-- upvalue path equals `skip` (nil when everything should match). It then
-- moves M2's functions into M1 and publishes the globals M2's chunk set
-- in ENV.
-- @return the upvalue trees collected for M1 and M2, in that order
local function fix(P, ENV, M1, M2, skip)
	local running = P:collectupval(M1)
	local replacement = P:collectupval(M2)
	local unmatched = P:join(replacement, running)
	testaux.asserteq(unmatched[1], skip, "test absent upvalue")
	-- install the replacement functions into the live module table
	for field, newfn in pairs(M2) do
		M1[field] = newfn
	end
	-- publish globals: values fill names that are currently unset (nil or
	-- false); functions always overwrite
	for key, value in pairs(ENV) do
		local unset = not _ENV[key]
		if unset or type(value) == "function" then
			_ENV[key] = value
		end
	end
	return running, replacement
end
return function()
local function case1(P)
local M1 = load([[
local testup1 = 3
local M = {}
@@ -53,10 +68,69 @@ return function()
print("test patch closure")
testaux.asserteq(M1.testfn2(), 7, "old module")
testaux.asserteq(M1.testfn2(), 11, "old module")
patch(ENV, M1, M2)
fix(P, ENV, M1, M2, "$.testfn2.testfn2.testfn2._ENV")
testaux.asserteq(M1.testfn2(), 19, "new module")
testaux.asserteq(M1.testfn2(), 27, "new module")
end
-- case2: the replacement module introduces an upvalue `step` that the
-- running module does not have. The join must report it as the absent path
-- "$.testfn2.testfn2.testfn2.step", while the shared `testup1` keeps the
-- running module's value.
local function case2(P)
local M1 = load([[
local testup1 = 3
local M = {}
local function testfn2()
testup1 = testup1 + 2
end
local function testfn3()
return function()
testfn2()
testfn2()
end
end
local testfn2 = testfn3()
function M.testfn2()
testfn2()
return testup1
end
return M
]])()
-- Bare environment (no __index): M2's chunk must not read any globals.
local ENV = {}
local M2 = load([[
local step = 3
local testup1 = 3
local M = {}
local function testfn2()
testup1 = testup1 + step
end
local function testfn3()
step = step + 1
return function()
testfn2()
testfn2()
end
end
local testfn2 = testfn3()
function M.testfn2()
testfn2()
return testup1
end
return M
]], nil, "t", ENV)()
print("test patch closure")
-- Old code adds 2 twice per call: 3 -> 7 -> 11.
testaux.asserteq(M1.testfn2(), 7, "old module")
testaux.asserteq(M1.testfn2(), 11, "old module")
fix(P, ENV, M1, M2, "$.testfn2.testfn2.testfn2.step")
-- New code: testfn3 ran once at load, so step is 4 and each call adds
-- 2 * 4 to the running testup1: 11 -> 19 -> 27.
testaux.asserteq(M1.testfn2(), 19, "new module")
testaux.asserteq(M1.testfn2(), 27, "new module")
end
local function case3(P)
local M1 = load([[
local core = require "sys.core"
local M = {}
@@ -84,11 +158,12 @@ return function()
local foo
local timer_foo
function timer_foo()
if not timer_foo then
return
end
foo = "world"
print("timer new")
if timer_foo then
core.timeout(500, timer_foo)
end
core.timeout(500, timer_foo)
end
function M.timer_foo()
timer_foo()
@@ -107,9 +182,102 @@ return function()
M1.timer_foo()
core.sleep(1000)
testaux.asserteq(M1.get_foo(), "hello", "old timer")
patch(ENV, M1, M2)
local up1, up2 = fix(P, ENV, M1, M2, nil)
local uv1 = up1.timer_foo.upvals.timer_foo
local uv2 = up2.timer_foo.upvals.timer_foo
testaux.asserteq(up1.timer_foo.upvals.timer_foo.upid,
uv1.upvals.timer_foo.upid, "test upvalueid")
debug.setupvalue(uv1.val, uv1.upvals.timer_foo.idx, uv2.upvals.timer_foo.val)
core.sleep(1000)
testaux.asserteq(M1.get_foo(), "world", "new timer")
M1.stop()
print("test patch success")
end
-- case4: the new foo captures `b`, which the old foo did not, so the join
-- reports "$.foo.b". The test then shares that `b` by hand with the new
-- bar's `b` (already joined to the running `b`). It checks that every
-- patched function works on the running counters, not on M2's fresh 0,0.
local function case4(P)
local M1 = load([[
local M = {}
local a,b = 3,4
function M.foo()
a = a + 1
return a
end
function M.foo2()
b = b + 2
return b
end
function M.bar()
return a, b
end
return M
]], nil, "t", _ENV)()
-- Sandbox for M2: reads fall back to the real globals.
local ENV = setmetatable({}, {__index = _ENV})
local M2 = load([[
local M = {}
local a,b = 0,0
function M.foo()
a = a + 1
b = b + 1
return a,b
end
function M.foo2()
b = b + 3
return b
end
function M.bar()
return a, b
end
return M
]], nil, "t", ENV)()
-- Old code: a 3 -> 4, b 4 -> 6.
testaux.asserteq(M1.foo(), 4, "old foo")
testaux.asserteq(M1.foo2(), 6, "old foo2")
local a, b = M1.bar()
testaux.asserteq(a, 4, "old bar")
testaux.asserteq(b, 6, "old bar")
local up1, up2 = fix(P, ENV, M1, M2, "$.foo.b")
-- Make the new foo's `b` refer to the new bar's `b` (the running cell).
debug.upvaluejoin(up2.foo.val, up2.foo.upvals.b.idx, up2.bar.val, up2.bar.upvals.b.idx)
-- New foo: a 4 -> 5, b 6 -> 7.
local a, b = M1.foo()
testaux.asserteq(a, 5, "new foo")
testaux.asserteq(b, 7, "new foo")
-- New foo2 adds 3 to the shared b: 7 -> 10.
testaux.asserteq(M1.foo2(), 10, "new foo")
local a, b = M1.bar()
testaux.asserteq(a, 5, "new foo")
testaux.asserteq(b, 10, "new foo")
end
-- case5: the old foo has no upvalues at all. The new foo uses the global
-- `bar`, so its `_ENV` upvalue has no counterpart and is reported as
-- "$.foo._ENV" (sys.patch's joinval rebinds an unmatched _ENV to its own
-- _ENV). `bar = 3` lands in ENV when M2 loads and is copied to the globals
-- by fix(). NOTE: this leaves the global `bar` set after the test.
local function case5(P)
local M1 = load([[
local M = {}
function M.foo()
end
return M
]], nil, "t", _ENV)()
-- Sandbox for M2: reads fall back to the real globals.
local ENV = setmetatable({}, {__index = _ENV})
local M2 = load([[
local M = {}
bar = 3
function M.foo()
bar = bar + 1
end
return M
]], nil, "t", ENV)()
fix(P, ENV, M1, M2, "$.foo._ENV")
testaux.asserteq(bar, 3, "global variable")
-- The patched foo now increments the real global.
M1.foo()
testaux.asserteq(bar, 4, "global variable")
end
-- Test entry point: run every patch case, in order, against a single shared
-- patch context (so upvalue caches and join bookkeeping carry across cases).
return function()
	local P = patch:create()
	local cases = {case1, case2, case3, case4, case5}
	for _, run_case in ipairs(cases) do
		run_case(P)
	end
end
0 comments
on commit 59941ab
Please sign in to comment.
Recommend
About Joyk
Aggregate valuable and interesting links.
Joyk means Joy of geeK