

TensorFlow.js:基于MobileNet的浏览器图片过滤插件
source link: https://blog.dev4eos.com/2019/05/24/tensorflowjs-mobilenet-filter/
Go to the source link to view the article. You can view the picture content, updated content and better typesetting reading experience. If the link is broken, please click the button below to view the snapshot at that time.

TensorFlow.js:基于MobileNet的浏览器图片过滤插件
有些人可能天生对蛇之类的图片感到敏感,又或者避免小孩子上网浏览到不健康的内容,这个时候我们可能需要对网页的图片建立一个前置过滤系统。
在以前做到这个可能很麻烦很费资源,当然,这一切在有了tf.js后要做到是很方便的。
MobileNet
Google提出的移动端模型MobileNet,其核心是采用了可分解的depthwise separable convolution,其不仅可以降低模型计算复杂度,而且可以大大降低模型大小。
根据tfjs的demo,我们下载一个训练好的模型。index.js
// URL of the hosted, pretrained MobileNet model loaded by the tfjs demo
// (the "0.25" variant, judging by the path).
const MOBILENET_MODEL_PATH =
// tslint:disable-next-line:max-line-length
'https://storage.googleapis.com/tfjs-models/tfjs/mobilenet_v1_0.25_224/model.json';
demo里用的是0.25大小的模型,准确性有点问题。
https://storage.googleapis.com/tfjs-models/tfjs/mobilenet_v1_0.75_224/model.json
我没找到这个在哪可以下载完整的,路径猜测发现0.75的存在,于是我写了个脚本把里面的分片文件一个个下载下来了
// Downloads every weight shard listed in a local copy of model.json and
// stores the shards next to it, so the model can be served from disk.
const fs = require('fs');
const fetch = require('node-fetch');
const path = require('path');

// Local directory that already contains model.json; shards are written here.
const MODEL_DIR = 'mobilenet-0.75';
// Remote directory hosting the shards referenced by model.json.
const REMOTE_BASE = 'https://storage.googleapis.com/tfjs-models/tfjs/mobilenet_v1_0.75_224/';

const config = JSON.parse(fs.readFileSync(path.join(MODEL_DIR, 'model.json'), 'utf-8'));

(async () => {
  // Shards are fetched one at a time, in manifest order.
  for (const group of config.weightsManifest) {
    for (const shardPath of group.paths) {
      const resp = await fetch(REMOTE_BASE + shardPath);
      // Without this check an HTTP error page would be saved as a weight file.
      if (!resp.ok) {
        throw new Error(`Failed to download ${shardPath}: HTTP ${resp.status}`);
      }
      const data = await resp.buffer();
      fs.writeFileSync(path.join(MODEL_DIR, shardPath), data);
    }
  }
})().catch((err) => {
  // Report the failure and signal it through the exit code.
  console.error(err);
  process.exitCode = 1;
});
background.js
准备一个分类的接口:接收一个图片地址,返回该图片的分类列表。
// Path to the locally bundled MobileNet 0.75 model. It replaces the hosted
// 0.25 model used by the demo:
// https://storage.googleapis.com/tfjs-models/tfjs/mobilenet_v1_0.25_224/model.json
const MOBILENET_MODEL_PATH = 'mobilenet-0.75/model.json';
// Width and height (in pixels) of the tensor fed to the model.
const IMAGE_SIZE = 224;
// Number of top-ranked classes returned per image.
const TOPK_PREDICTIONS = 10;

// Set once the model has finished loading; undefined until then.
let mobilenet;
(async () => {
  mobilenet = await tf.loadLayersModel(MOBILENET_MODEL_PATH);
  // Warm up the model with an all-zero input and discard the result.
  mobilenet.predict(tf.zeros([1, IMAGE_SIZE, IMAGE_SIZE, 3])).dispose();
})().catch((err) => {
  // Surface load failures instead of leaving the promise floating.
  console.error('Failed to load MobileNet model', err);
});
接收一个Image对象,返回概率最高的若干个分类(Top-K),详见tfjs官方的MobileNet示例。
/**
 * Classifies an image element with the loaded MobileNet model.
 *
 * @param imgElement Image element to classify. Its pixels are reshaped to
 *     [1, IMAGE_SIZE, IMAGE_SIZE, 3], so its dimensions must match that size.
 * @returns Promise resolving to whatever getTopKClasses() produces for the
 *     model output, limited to TOPK_PREDICTIONS entries.
 */
async function predict(imgElement) {
  // NOTE(review): `status` is not defined in the visible snippet; it is
  // assumed to be provided elsewhere in the script.
  status('Predicting...');

  // tf.tidy() frees the intermediate tensors; only the returned one survives.
  const logits = tf.tidy(() => {
    // tf.browser.fromPixels() returns a Tensor from an image element.
    const img = tf.browser.fromPixels(imgElement).toFloat();
    // Normalize the image from [0, 255] to [-1, 1].
    const offset = tf.scalar(127.5);
    const normalized = img.sub(offset).div(offset);
    // Reshape to a single-element batch so we can pass it to predict.
    const batched = normalized.reshape([1, IMAGE_SIZE, IMAGE_SIZE, 3]);
    return mobilenet.predict(batched);
  });

  try {
    return await getTopKClasses(logits, TOPK_PREDICTIONS);
  } finally {
    // The tensor returned from tidy() is owned by us; release it even when
    // ranking fails, otherwise one tensor leaks per classified image.
    logits.dispose();
  }
}
/**
 * Ranks the model output and returns the topK highest-scoring classes.
 *
 * The output values are used as-is as the reported probabilities; no softmax
 * is applied here.
 *
 * @param logits Tensor-like object whose data() resolves to the per-class
 *     output values from MobileNet.
 * @param topK The number of top predictions to return. Values larger than
 *     the number of outputs are clamped, so fewer entries may come back.
 * @returns Promise resolving to an array of {className, probability}
 *     objects, sorted by descending probability.
 */
async function getTopKClasses(logits, topK) {
  const values = await logits.data();

  // Pair every value with its class index so the index survives sorting.
  const ranked = Array.from(values, (value, index) => ({ value, index }));
  ranked.sort((a, b) => b.value - a.value);

  // Clamp so that asking for more classes than exist cannot read past the end.
  const count = Math.max(0, Math.min(topK, ranked.length));

  return ranked.slice(0, count).map(({ value, index }) => ({
    className: _imagenet_classes.IMAGENET_CLASSES[index],
    probability: value,
  }));
}
图片URL转base64,插件manifest.json需要 "permissions": ["http://*/*", "https://*/*" ],
/**
 * Loads an image from a URL and hands its contents to `cb` as a data URL.
 *
 * @param srcUrl URL of the image to load.
 * @param cb Called with the canvas.toDataURL() string once the image has
 *     loaded and been drawn at its own width and height.
 * @param onError Optional. Called with an Error when the image fails to load
 *     or cannot be converted; when omitted the error is logged instead.
 */
function getDataUrl(srcUrl, cb, onError) {
  const image = new Image();

  const fail = (err) => {
    if (typeof onError === 'function') {
      onError(err);
    } else {
      console.error(err);
    }
  };

  image.onload = function () {
    let dataUrl;
    try {
      const canvas = document.createElement('canvas');
      canvas.width = image.width;
      canvas.height = image.height;
      canvas.getContext('2d').drawImage(image, 0, 0);
      dataUrl = canvas.toDataURL();
    } catch (err) {
      fail(err);
      return;
    }
    // Called outside the try block so errors thrown by cb are not rerouted.
    cb(dataUrl);
  };

  image.onerror = function () {
    fail(new Error('Failed to load image: ' + srcUrl));
  };

  // Start loading only after both handlers are attached.
  image.src = srcUrl;
}
监听页面发来的请求,把图片URL转成base64并缩放到模型的输入尺寸,拿到分类结果后判断是否包含要过滤的分类,再把结果发回给发出请求的页面。
// Substrings of class names that cause an image to be blocked.
const keywords = ['snake', 'cobra'];
// IMAGE_SIZE is intentionally not redeclared here: background.js already
// declares it next to MOBILENET_MODEL_PATH, and a second `const` with the
// same name in the same scope is a redeclaration error.

// Handles classification requests of the form {src} sent by content scripts.
// The result goes back to the requesting tab as a separate message
// ({action: "ret", src, status, classes}); sendResponseA is not used.
chrome.runtime.onMessage.addListener(function (request, sender, sendResponseA) {
  console.log('onMessage', request.src);

  // status === true means the image matched a keyword and must stay hidden.
  function sendResponse(status, classes) {
    chrome.tabs.sendMessage(sender.tab.id, {
      action: "ret",
      src: request.src,
      status: status,
      classes: classes
    }, function (response) {
    });
  }

  getDataUrl(request.src, function (res) {
    // Setting width/height on the element scales the image to the model's
    // input size before its pixels are read.
    const img = document.createElement('img');
    img.width = IMAGE_SIZE;
    img.height = IMAGE_SIZE;

    // Any failure reports status=false, which the content script treats as
    // "not blocked" (the image is shown again).
    img.onerror = function () {
      sendResponse(false);
    };

    img.onload = async function () {
      try {
        const classes = await predict(img);
        const hitBlack = classes.some(function (entry) {
          return keywords.some(function (keyword) {
            return entry.className.includes(keyword);
          });
        });
        console.log('classes', classes);
        sendResponse(hitBlack, classes);
      } catch (e) {
        // Log instead of silently swallowing so broken predictions are visible.
        console.error('classification failed', request.src, e);
        sendResponse(false);
      }
    };

    // Assign the source last, after both handlers are in place.
    img.src = res;
  });
});
content_script
内容页JS需要先把图片的透明度设为0将其隐藏,扫描页面上已有的所有图片并监听新出现的图片,然后把它们都发送给background.js进行分类。
// Maps an image URL to the <img> elements that are waiting for (or have
// received) a classification result for that URL.
const eleCache = {};

/**
 * Scans the document for images that have not been handled yet, hides them,
 * and asks the background page to classify each one.
 *
 * An image is skipped when it already carries a `classify` or `fetching`
 * attribute, or when its width is 50 or less.
 */
function findImages() {
  // Snapshot the live collection before iterating.
  const images = Array.from(document.getElementsByTagName('img'));

  for (const image of images) {
    if (image.getAttribute('classify') != null) continue;
    if (image.getAttribute('fetching') != null) continue;
    if (!(image.width > 50)) continue;

    // Only touch opacity; assigning a string to `style` would wipe every
    // other inline style the page put on the element.
    image.style.opacity = '0';

    eleCache[image.src] = eleCache[image.src] || [];
    eleCache[image.src].push(image);
    image.setAttribute('fetching', '1');

    try {
      // chrome.runtime.sendMessage replaces the deprecated
      // chrome.extension.sendMessage. The result arrives as a separate
      // message, so the response callback has nothing to do.
      chrome.runtime.sendMessage({
        src: image.src
      }, function (hitBlack) {
      });
    } catch (e) {
      // Best effort: keep scanning the remaining images, but leave a trace.
      console.warn('findImages: failed to request classification', image.src, e);
    }
  }
}
// Poll for images every 50 ms during the first 5 seconds, so images are
// picked up while the page is still loading.
var timeStart = Date.now();
// Currently unused; kept so the set of top-level names stays unchanged.
var longRun = null;
var timer = setInterval(function () {
  var elapsed = Date.now() - timeStart;
  if (elapsed > 5000) {
    clearInterval(timer);
    return;
  }
  findImages();
}, 50);

var MutationObserver = window.MutationObserver || window.WebKitMutationObserver || window.MozMutationObserver;

// Scans once, then rescans whenever anything under <body> changes.
function startObserving() {
  findImages();
  var observer = new MutationObserver(function (mutations) {
    findImages();
  });
  var config = { attributes: true, childList: true, characterData: true, subtree: true };
  observer.observe(document.body, config);
}

// If this script is injected after the load event has fired, a load handler
// would never run, so start immediately in that case. addEventListener is
// used so no existing window.onload handler is replaced.
if (document.readyState === 'complete') {
  startObserving();
} else {
  window.addEventListener('load', startObserving);
}
监听分类结果并显示不需要屏蔽的图片
// Receives classification results from the background page and reveals the
// images that were not blocked. Messages without an `action` are ignored.
chrome.runtime.onMessage.addListener(function (request, sender, sendResponseA) {
  if (!request.action) {
    return;
  }

  var src = request.src;
  var classes = request.classes;
  var els = eleCache[src];
  // Nothing is waiting on this URL (e.g. a duplicate reply already handled it).
  if (!els) {
    return;
  }

  // A truthy status means the image matched the filter and stays hidden.
  var blocked = Boolean(request.status);
  var classNames = (classes || []).map(function (entry) {
    return entry.className;
  });

  els.forEach(function (el) {
    if (!blocked) {
      // Only touch opacity so other inline styles are left alone.
      el.style.opacity = '1';
    }
    // Mark as processed so findImages() skips the element from now on.
    el.setAttribute('classify', 1);
    // Fall back to an empty list so the attribute never reads "undefined".
    el.setAttribute('cats', JSON.stringify(classes || []));
    el.setAttribute('alt', classNames.join("\n"));
  });

  console.log('image', src, blocked ? 'blocked' : 'notinblacklist');

  // Release the cached elements; otherwise every image seen on the page
  // stays referenced for the lifetime of the content script.
  delete eleCache[src];
});
以上一个简单基于tfjs的前端图片过滤系统就实现了。
此外我们还可以增加一个UI来设置要过滤的分类。或者增加其他的一些模型进来完善各种特殊的过滤需求。
Recommend
About Joyk
Aggregate valuable and interesting links.
Joyk means Joy of geeK