3

refine patch(more flexible) · findstr/silly@59941ab · GitHub

 2 years ago
source link: https://github.com/findstr/silly/commit/59941ab0f318408c581f14fc8cef5ca55b0e63ec
Follow the source link to read the original article, where you will find the images, any later updates, and better typesetting. If the link is broken, click the button below to view a snapshot taken at that time.

refine patch(more flexible) · findstr/silly@59941ab · GitHub

Permalink

This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Browse files

refine patch(more flexible)

findstr committed on Oct 28, 2021

1 parent 045140c commit 59941ab0f318408c581f14fc8cef5ca55b0e63ec
Showing with 357 additions and 94 deletions.

@@ -0,0 +1,19 @@

local console = require "sys.console"

local run = require "patch.run"

-- Render the current state of the patched module as one line of text.
-- `step` is read as a global at call time, so a value installed by a later
-- hotfix shows up here without reloading this script.
local function report()
	local a, b = run.dump()
	return string.format("a:%s,b:%s,step:%s", a, b, tostring(step))
end

-- Start the console service on a local address with two extra commands that
-- invoke the corresponding `patch.run` function and then print its state.
console {
	addr = "127.0.0.1:2345",
	cmd = {
		foo = function()
			run.foo()
			return report()
		end,
		bar = function()
			run.bar()
			return report()
		end
	}
}

@@ -0,0 +1,20 @@

-- Replacement version of a small module exposing foo/bar/dump.
-- NOTE(review): presumably this is the "examples/patch/fix.lua" file loaded
-- by the hotfix script in this commit -- confirm against the repository.
-- The hot-patch code matches upvalues by NAME, so the local names `a` and `b`
-- below must stay identical to the ones in the module being patched.
local M = {}

-- Assigned without `local`, so it lands in whatever environment table this
-- chunk was loaded with.
step = 3

-- Module state shared (as upvalues) by the three functions below.
local a, b = 1, 1

-- Set the global `step` to 4, then add it to both `a` and `b`.
function M.foo()

step = 4

a = a + step

b = b + step

end

-- Increment `b` by one.
function M.bar()

b = b + 1

end

-- Return the current values of `a` and `b`.
function M.dump()

return a, b

end

return M

@@ -0,0 +1,20 @@

local patch = require "sys.patch"

local run = require "patch.run"

-- Hotfix driver: load the replacement module in a private environment, bind
-- its upvalues to the live ones of the running `patch.run` module, then swap
-- the new functions in.

-- Private environment for the fix chunk: reads fall through to the real
-- _ENV, writes (new globals) stay in ENV until they are published below.
local ENV = setmetatable({}, {__index = _ENV})

-- loadfile returns nil plus a message on failure; assert surfaces that
-- message instead of a bare "attempt to call a nil value".
local loader = assert(loadfile("examples/patch/fix.lua", "bt", ENV))
local fix = loader()

local P = patch:create()
local up1 = P:collectupval(run)
local up2 = P:collectupval(fix)

-- Rebind upvalues of the new functions to the same-named upvalues of the
-- running ones. `absent` lists the paths of upvalues that exist only in the
-- new code and therefore could not be bound automatically.
local absent = P:join(up2, up1)
print(absent[1], absent[2])

-- The running foo() has no upvalue `b`, so the new foo's `b` was left
-- unbound by join. Point it at the `b` of the new dump(), which join has
-- already bound to the live value.
debug.upvaluejoin(up2.foo.val, up2.foo.upvals.b.idx, up2.dump.val, up2.dump.upvals.b.idx)

-- Replace the running module's functions in place, so existing references to
-- the module table see the new code.
for k, v in pairs(fix) do
	run[k] = v
end

-- Publish globals created by the fix chunk: functions always overwrite,
-- other values only fill slots that are currently unset (or false).
for k, v in pairs(ENV) do
	if type(v) == "function" or not _ENV[k] then
		_ENV[k] = v
	end
end

-- Reported only after every step above has completed.
print("hotfix ok")

@@ -0,0 +1,17 @@

-- Small module exposing foo/bar/dump over two private counters.
-- NOTE(review): presumably this is the running "patch.run" module that the
-- hotfix example in this commit patches -- confirm against the repository.
local M = {}

-- Module state shared (as upvalues) by the functions below. The hot-patch
-- code matches upvalues by name, so a replacement module must reuse these
-- names to inherit the live values.
local a, b = 1, 1

-- Increment `a` by one. Note that this function does not reference `b`.
function M.foo()

a = a + 1

end

-- Increment `b` by one.
function M.bar()

b = b + 1

end

-- Return the current values of `a` and `b`.
function M.dump()

return a, b

end

return M

@@ -14,7 +14,7 @@ cluster() {

}

module() {

./silly --lualib_path="lualib/?.lua" --lualib_cpath="luaclib/?.so" --bootstrap="examples/$1.lua"

./silly --lualib_path="lualib/?.lua;examples/?.lua" --lualib_cpath="luaclib/?.so" --bootstrap="examples/$1.lua"

}

all() {

@@ -28,7 +28,7 @@ local desc = {

"CPUINFO: Show system time and user time statistics. [CPUINFO]",

"SOCKET: Show socket detail information. [SOCKET]",

"TASK: Show all task status and traceback. [TASK]",

"PATCH: Hot patch the code. [PATCH <modulename> <filename>]",

"INJECT: INJECT code. [INJECT <path>]",

"DEBUG: Enter Debug mode. [DEBUG]",

}

@@ -37,17 +37,6 @@ local console = {}

local envmt = {__index = _ENV}

-- Hot-patch an already loaded module with the code found in a file.
--
-- @param module   name passed to `require` to obtain the running module
-- @param filename path of the Lua file that returns the replacement module
--
-- Raises an error if the file cannot be loaded or if either module is not a
-- table. The replacement chunk runs in a fresh environment whose reads fall
-- through to the real _ENV (via `envmt`).
local function _patch(module, filename)
	local ENV = setmetatable({}, envmt)
	local runm = require(module)
	-- The original passed the undefined name `fixfile` here. That evaluates
	-- to nil, and loadfile(nil) reads the chunk from standard input instead
	-- of the requested file.
	local newm = assert(loadfile(filename, "bt", ENV))()
	assert(runm and type(runm) == "table")
	assert(newm and type(newm) == "table")
	patch(ENV, runm, newm)
end

-- HELP command: return the `desc` table defined earlier in this file, i.e.
-- the list of one-line usage strings for the built-in console commands.
function console.help()

return desc

end

@@ -172,18 +161,20 @@ function console.info()

return concat(tbl, "\r\n")

end

function console.patch(_, module, filename)

if not module then

return "ERR lost the module file name"

elseif not filename == 0 then

return "ERR lost the filename"

function console.inject(_, filepath)

if not filepath == 0 then

return "ERR lost the filepath"

end

local ENV = setmetatable({}, envmt)

local ok, err = pcall(loadfile, filepath, bt, ENV)

if ok then

ok, err = pcall(err)

end

local ok, err = pcall(_patch, module, filename)

local fmt = "Patch module:%s function:%s by:%s %s"

local fmt = "Inject file:%s %s"

if ok then

return format(fmt, module, fix, "Success")

return format(fmt, filepath, "Success")

else

return format(fmt, module, fix, err)

return format(fmt, filepath, err)

end

end

@@ -228,6 +219,7 @@ return function (config)

core.log("console come in:", addr)

local param = {}

local dat = {}

socket.write(fd, "Hello\n")

socket.write(fd, prompt)

while true do

local l = socket.readline(fd)

@@ -1,110 +1,137 @@

local type = type

local pairs = pairs

local assert = assert

local dgetinfo = debug.getinfo

local dgetupval = debug.getupvalue

local dupvaljoin = debug.upvaluejoin

local concat = table.concat

local setmetatable = setmetatable

local getinfo = debug.getinfo

local getupval = debug.getupvalue

local setupvalue = debug.setupvalue

local upvalid = debug.upvalueid

local upvaljoin = debug.upvaluejoin

local function collectup(func, tbl, unique)

local info = dgetinfo(func, "u")

local M = {}

local mt = {__index = M}

local function collectfn(fn, upvals, unique)

local info = getinfo(fn, "u")

for i = 1, info.nups do

local up

local name, val = dgetupval(func, i)

local name, val = getupval(fn, i)

local utype = type(val)

if utype == "function" then

local upval = unique[val]

if not upval then

upval = {}

unique[val] = upval

collectup(val, upval, unique)

collectfn(val, upval, unique)

end

up = {

idx = i,

utype = utype,

val = val,

up = upval

upid = upvalid(fn, i),

upvals = upval

}

else

up = {

idx = i,

utype = utype,

val = val,

upid = upvalid(fn, i),

}

end

tbl[name] = up

upvals[name] = up

end

end

-- Build a new patch context. Takes no arguments (calling it as
-- `patch:create()` works because the receiver is simply ignored).
-- The context carries three empty bookkeeping tables and uses `mt`, so the
-- methods defined on M are available on the returned object.
function M.create()
	local ctx = {
		collected = {},
		valjoined = {},
		fnjoined = {},
	}
	return setmetatable(ctx, mt)
end

local function join_val(joined, nfv, nup, rfv, rup)

assert(type(nfv) == "function")

assert(type(rfv) == "function")

for name, nu in pairs(nup) do

local ru = rup[name]

if ru then

local nidx = nu.idx

local ridx = ru.idx

assert(ru.utype == nu.utype or not nu.val or not ru.val)

if nu.utype == "function" then

local v = nu.val

if not joined[v] then

joined[v] = true

join_val(joined, v, nu.up, ru.val, ru.up)

end

else

dupvaljoin(nfv, nidx, rfv, ridx)

end

-- Collect upvalue descriptions for a function or for every function stored
-- in a module table.
--
-- @param f_or_t a function, or a table whose function-valued fields are the
--               module's functions
-- @return for a function: a table mapping upvalue name -> descriptor as
--         built by collectfn;
--         for a table: a table mapping field name ->
--         { val = <function>, upvals = <descriptors of that function> }
--
-- Function-valued upvalues already seen by this context are shared through
-- `self.collected`.
function M:collectupval(f_or_t)
	local upvals = {}
	local unique = self.collected
	if type(f_or_t) == "table" then
		for name, fn in pairs(f_or_t) do
			-- Modules may also carry constants or sub-tables. Those have no
			-- upvalues, and handing them to debug.getinfo would either raise
			-- or (for numbers) be read as a stack level, so skip them.
			if type(fn) == "function" then
				local x = {}
				collectfn(fn, x, unique)
				upvals[name] = {
					val = fn,
					upvals = x
				}
			end
		end
	else
		collectfn(f_or_t, upvals, unique)
	end
	return upvals
end

local function join_fn(joined, rfv, rup, nfv, nup)

assert(type(nfv) == "function")

assert(type(rfv) == "function")

for name, ru in pairs(rup) do

local nu = nup[name]

if nu then

local nidx = nu.idx

local ridx = ru.idx

if nu.utype == "function" then

assert(ru.utype == nu.utype)

local v = nu.val

local function joinval(f1, up1, f2, up2, joined, path, absent)

local n = 0

local depth = #path + 1

for name, uv1 in pairs(up1) do

path[depth] = name

local uv2 = up2[name]

if uv2 then

local idx1 = uv1.idx

local idx2 = uv2.idx

assert(uv1.utype == uv2.utype or not uv1.val or not uv2.val)

if uv1.utype == "function" then

assert(uv2.utype == "function")

local v = uv1.val

if not joined[v] then

joined[v] = true

join_fn(joined, ru.val, ru.up, nu.val, nu.up)

n = n + joinval(v, uv1.upvals,

uv2.val, uv2.upvals,

joined, path, absent)

end

dupvaljoin(rfv, ridx, nfv, nidx)

else

upvaljoin(f1, idx1, f2, idx2)

end

elseif name == "_ENV" then

setupvalue(f1, uv1.idx, _ENV)

absent[#absent + 1] = concat(path, ".")

else

absent[#absent + 1] = concat(path, ".")

end

path[depth] = nil

end

return n

end

return function(newenv, runm, newm)

local unique = {}

local val_joined = {}

local fn_joined = {}

--join the runm to newm

for fn, nfv in pairs(newm) do

local rfv = runm[fn]

if rfv then

local ru = {}

local nu = {}

collectup(rfv, ru, unique)

collectup(nfv, nu, unique)

join_val(val_joined, nfv, nu, rfv, ru)

join_fn(fn_joined, rfv, ru, nfv, nu)

end

end

--replace runm to newm

for fn, nfv in pairs(newm) do

runm[fn] = nfv

end

--fix _ENV

for k, v in pairs(newenv) do

if not _ENV[k] then

_ENV[k] = v

-- Rebind the upvalues of one set of functions to the same-named upvalues of
-- another set.
--
-- Two call forms are supported:
--   P:join(t1, t2)            -- t1, t2: results of collectupval(table);
--                                every function name present in both tables
--                                is joined, paths start with "$.<name>"
--   P:join(f1, up1, f2, up2)  -- a single pair of functions with their
--                                collected upvalues, paths start with "$.#"
--
-- In both forms the upvalues of the FIRST side are redirected to those of
-- the second side (see joinval).
--
-- @return a list of dotted paths naming first-side upvalues that have no
--         counterpart on the second side
function M:join(f1, up1, f2, up2)
	local path = {"$"}
	local absent = {}
	local joined = self.valjoined
	if type(f1) == "table" then
		-- Table form: the first two arguments are the two upvalue tables.
		local ups1, ups2 = f1, up1
		for name, uv1 in pairs(ups1) do
			local uv2 = ups2[name]
			if uv2 then
				path[2] = name
				joinval(uv1.val, uv1.upvals,
					uv2.val, uv2.upvals,
					joined, path, absent)
				path[2] = nil
			end
		end
	else
		assert(type(f1) == "function")
		assert(type(f2) == "function")
		path[2] = "#"
		joinval(f1, up1, f2, up2, joined, path, absent)
		path[2] = nil
	end
	-- The original had a bare `return` directly before this line, which is a
	-- syntax error. Its unused counter `n` has been dropped as well.
	return absent
end

return M

@@ -1,8 +1,23 @@

local core = require "sys.core"

local patch = require "sys.patch"

local testaux = require "testaux"

-- Test helper: patch running module M1 with replacement module M2.
--
-- @param P    patch context
-- @param ENV  environment table the M2 chunk was loaded with
-- @param M1   running module (modified in place)
-- @param M2   replacement module
-- @param skip expected first entry of the absent-upvalue list (nil if none)
-- @return the collected upvalue tables of M1 and of M2
local function fix(P, ENV, M1, M2, skip)
	local running = P:collectupval(M1)
	local fresh = P:collectupval(M2)
	local missing = P:join(fresh, running)
	testaux.asserteq(missing[1], skip, "test absent upvalue")
	-- Install the replacement functions on the running module.
	for key, impl in pairs(M2) do
		M1[key] = impl
	end
	-- Publish globals created by the replacement chunk: unset slots are
	-- filled, and functions always overwrite.
	for key, value in pairs(ENV) do
		local current = _ENV[key]
		if not current or type(value) == "function" then
			_ENV[key] = value
		end
	end
	return running, fresh
end

return function()

local function case1(P)

local M1 = load([[

local testup1 = 3

local M = {}

@@ -53,10 +68,69 @@ return function()

print("test patch closure")

testaux.asserteq(M1.testfn2(), 7, "old module")

testaux.asserteq(M1.testfn2(), 11, "old module")

patch(ENV, M1, M2)

fix(P, ENV, M1, M2, "$.testfn2.testfn2.testfn2._ENV")

testaux.asserteq(M1.testfn2(), 19, "new module")

testaux.asserteq(M1.testfn2(), 27, "new module")

end

-- Case 2: the exported function reaches the shared counter `testup1` only
-- through nested closures, and the replacement code introduces an upvalue
-- (`step`) that the running code does not have.
-- @param P patch context shared by all cases
local function case2(P)

-- Running module: every M.testfn2() call runs the inner testfn2 twice,
-- adding 2 each time, so the counter grows by 4 per call.
local M1 = load([[

local testup1 = 3

local M = {}

local function testfn2()

testup1 = testup1 + 2

end

local function testfn3()

return function()

testfn2()

testfn2()

end

end

local testfn2 = testfn3()

function M.testfn2()

testfn2()

return testup1

end

return M

]])()

-- Environment for the replacement chunk (it defines no globals).
local ENV = {}

-- Replacement module: testfn3() runs once while the chunk loads and bumps
-- `step` from 3 to 4, so the counter grows by 8 per M.testfn2() call.
local M2 = load([[

local step = 3

local testup1 = 3

local M = {}

local function testfn2()

testup1 = testup1 + step

end

local function testfn3()

step = step + 1

return function()

testfn2()

testfn2()

end

end

local testfn2 = testfn3()

function M.testfn2()

testfn2()

return testup1

end

return M

]], nil, "t", ENV)()

print("test patch closure")

-- Before patching: 3 -> 7 -> 11.
testaux.asserteq(M1.testfn2(), 7, "old module")

testaux.asserteq(M1.testfn2(), 11, "old module")

-- `step` exists only in the new code, so it is expected to be reported as
-- the first absent upvalue.
fix(P, ENV, M1, M2, "$.testfn2.testfn2.testfn2.step")

-- After patching the counter continues from the live value 11 with the new
-- increment: 11 -> 19 -> 27.
testaux.asserteq(M1.testfn2(), 19, "new module")

testaux.asserteq(M1.testfn2(), 27, "new module")

end

local function case3(P)

local M1 = load([[

local core = require "sys.core"

local M = {}

@@ -84,11 +158,12 @@ return function()

local foo

local timer_foo

function timer_foo()

if not timer_foo then

return

end

foo = "world"

print("timer new")

if timer_foo then

core.timeout(500, timer_foo)

end

core.timeout(500, timer_foo)

end

function M.timer_foo()

timer_foo()

@@ -107,9 +182,102 @@ return function()

M1.timer_foo()

core.sleep(1000)

testaux.asserteq(M1.get_foo(), "hello", "old timer")

patch(ENV, M1, M2)

local up1, up2 = fix(P, ENV, M1, M2, nil)

local uv1 = up1.timer_foo.upvals.timer_foo

local uv2 = up2.timer_foo.upvals.timer_foo

testaux.asserteq(up1.timer_foo.upvals.timer_foo.upid,

uv1.upvals.timer_foo.upid, "test upvalueid")

debug.setupvalue(uv1.val, uv1.upvals.timer_foo.idx, uv2.upvals.timer_foo.val)

core.sleep(1000)

testaux.asserteq(M1.get_foo(), "world", "new timer")

M1.stop()

print("test patch success")

end

-- Case 4: the replacement foo() uses an upvalue (`b`) that the running foo()
-- never referenced, so it cannot be joined by name and has to be bound
-- manually after the patch.
-- @param P patch context shared by all cases
local function case4(P)

-- Running module: foo touches only `a`, foo2 touches only `b`.
local M1 = load([[

local M = {}

local a,b = 3,4

function M.foo()

a = a + 1

return a

end

function M.foo2()

b = b + 2

return b

end

function M.bar()

return a, b

end

return M

]], nil, "t", _ENV)()

-- Private environment for the replacement chunk; reads fall through to _ENV.
local ENV = setmetatable({}, {__index = _ENV})

-- Replacement module: foo now increments both `a` and `b`, foo2 adds 3.
local M2 = load([[

local M = {}

local a,b = 0,0

function M.foo()

a = a + 1

b = b + 1

return a,b

end

function M.foo2()

b = b + 3

return b

end

function M.bar()

return a, b

end

return M

]], nil, "t", ENV)()

-- Advance the running state to a = 4, b = 6 before patching.
testaux.asserteq(M1.foo(), 4, "old foo")

testaux.asserteq(M1.foo2(), 6, "old foo2")

local a, b = M1.bar()

testaux.asserteq(a, 4, "old bar")

testaux.asserteq(b, 6, "old bar")

-- The new foo's `b` has no counterpart in the old foo and is reported absent.
local up1, up2 = fix(P, ENV, M1, M2, "$.foo.b")

-- Bind the new foo's `b` to the new bar's `b`, which join has already
-- attached to the live value.
debug.upvaluejoin(up2.foo.val, up2.foo.upvals.b.idx, up2.bar.val, up2.bar.upvals.b.idx)

-- New foo continues from the live values: a 4 -> 5, b 6 -> 7.
local a, b = M1.foo()

testaux.asserteq(a, 5, "new foo")

testaux.asserteq(b, 7, "new foo")

-- New foo2 adds 3 to the same `b`: 7 -> 10.
testaux.asserteq(M1.foo2(), 10, "new foo")

local a, b = M1.bar()

testaux.asserteq(a, 5, "new foo")

testaux.asserteq(b, 10, "new foo")

end

-- Case 5: the replacement code introduces a global variable and therefore an
-- `_ENV` upvalue that the running function does not have.
-- @param P patch context shared by all cases
local function case5(P)

-- Running module: foo is empty and has no upvalues.
local M1 = load([[

local M = {}

function M.foo()

end

return M

]], nil, "t", _ENV)()

-- Private environment for the replacement chunk; the global `bar` assigned
-- by that chunk is stored here first.
local ENV = setmetatable({}, {__index = _ENV})

-- Replacement module: defines global `bar` and increments it in foo.
local M2 = load([[

local M = {}

bar = 3

function M.foo()

bar = bar + 1

end

return M

]], nil, "t", ENV)()

-- The new foo's `_ENV` is reported absent; the helper then copies `bar`
-- from ENV into this chunk's _ENV.
fix(P, ENV, M1, M2, "$.foo._ENV")

testaux.asserteq(bar, 3, "global variable")

-- The patched foo must update the `bar` visible here, not the copy in ENV.
-- NOTE(review): this relies on the patch module's _ENV being the same table
-- as this file's _ENV -- confirm.
M1.foo()

testaux.asserteq(bar, 4, "global variable")

end

-- Test entry point: create one patch context and run every case against it,
-- in order, so state cached inside the context is shared between cases.
return function()
	local P = patch:create()
	local cases = {case1, case2, case3, case4, case5}
	for i = 1, #cases do
		cases[i](P)
	end
end

0 comments on commit 59941ab

Please sign in to comment.


About Joyk


Aggregate valuable and interesting links.
Joyk means Joy of geeK