- 一、处理异常
- 二、区分不同请求的工作目录
- UUID
- 对 Task 类进行重构
- 三、校验代码的安全性
- 四、阶段性总结
书接上回,我们自己测试没问题,是因为使用了正常数据;万一用户输入的是非法的请求,该咋办?
我们需要处理异常请求,修改整个代码框架。
一、处理异常
为了防止用户输入异常 ID,我们创建 ProblemNotFoundException 异常类来处理。
为了防止用户提交有问题的代码,我们创建 CodeInValidException 异常类来处理。
统一在 catch 处理异常代码。
整理整体代码结构,去除冗余代码,最后 CompileServlet 类代码如下:
@WebServlet("/compile")
public class CompileServlet extends HttpServlet {
static class CompileRequest {
public int id;
public String code;
}
static class CompileResponse {
// 0 表示没问题,1 表示编译出错,2 表示运行异常,3 表示其它错误
public int error;
public String reason;
public String stdout;
}
private ObjectMapper objectMapper = new ObjectMapper();
@Override
protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
CompileRequest compileRequest = new CompileRequest();
CompileResponse compileResponse = new CompileResponse();
try {
resp.setStatus(200);
resp.setContentType("application/json;charset=utf8");
// 1. 读取请求的正文
String body = readBody(req);
// 类对象,获取类的信息
compileRequest = objectMapper.readValue(body, CompileRequest.class);
// 2. 根据 id 从数据库中查找到题目的详情 - 得到测试用例代码
ProblemDAO problemDAO = new ProblemDAO();
Problem problem = problemDAO.selectOne(compileRequest.id);
// 处理用户输入异常 id,导致查不到题目
if (problem == null) {
// 为了统一处理错误,在这个地方抛出一个异常
throw new ProblemNotFoundException();
}
// testCode 是测试用例的代码
String testCode = problem.getTestCode();
// requestCode 是用户提交的代码
String requestCode = compileRequest.code;
// 3. 把用户提交的代码和测试用例代码,拼接成一个完整的代码
String finalCode = mergeCode(requestCode, testCode);
// 处理用户提交有问题的代码
if (finalCode == null) {
throw new CodeInValidException();
}
// System.out.println(finalCode);
// 4. 创建一个 Task 实例,调用里面的 compileAndRun 来解析编译运行
Task task = new Task();
Question question = new Question();
question.setCode(finalCode);
Answer answer = task.compileAndRun(question);
// 5. 根据 Task 运行的结果,包装成一个 HTTP 响应
compileResponse.error = answer.getError();
compileResponse.reason = answer.getReason();
compileResponse.stdout = answer.getStdout();
} catch (ProblemNotFoundException e) {
// 处理题目没有找到异常
compileResponse.error = 3;
compileResponse.reason = "没有找到指定题目!id = " + compileRequest.id;
} catch (CodeInValidException e) {
// 处理用户提交的代码有问题
compileResponse.error = 3;
compileResponse.reason = "提交的代码不符合要求!";
} finally {
String respString = objectMapper.writeValueAsString(compileResponse);
resp.getWriter().write(respString);
}
}
// 拼接代码
private static String mergeCode(String requestCode, String testCode) {
// 1. 查找 requestCode 最后一个 }
int pos = requestCode.lastIndexOf("}");
if (pos == -1) {
return null;
}
// 2. 截取字符串
String substring = requestCode.substring(0, pos);
// 3. 拼接字符串并返回
return substring + testCode + "\n}";
}
// 通过请求头获取数据,转换成String 返回
private static String readBody(HttpServletRequest req) throws UnsupportedEncodingException {
// 1. 根据请求头里面的 ContentLength 获取到 body 的长度(单位是字节)
int contentLength = req.getContentLength();
// 2. 按照这个长度准备好一个 byte[]
byte[] buffer = new byte[contentLength];
// 3. 通过 req 里面的 getInputStream 方法,获取到 body 的流对象
try (InputStream inputStream = req.getInputStream()) {
// 4. 基于这个流对象,读取内容,然后把内容放到 byte[] 数字中即可
inputStream.read(buffer);
} catch (IOException e) {
e.printStackTrace();
}
// 5. 把这个 byte[] 的内容构造成一个 String,同时设置转换字符集格式
return new String(buffer, "utf8");
}
}
测试一波~
输入错误 id 能够捕捉异常。
二、区分不同请求的工作目录
问题引入
每次有一个请求过来,都需要生成一组临时文件。
如果同一时刻,有 N 个请求一起过来,这些临时文件和目录都是一样的。
此时多个请求之间就会出现 “相互干扰” 的情况(非常类似于线程安全问题)。
这三个请求,里面的题目和提交的代码都是一样的吗?都是不一样的!
因为这是来自三个不同用户的请求。
如果我们使用同一份目录里面的同一份文件,就会出现这种相互干扰的情况!
解决方法
我们需要让每个请求,都有一个自己的目录来存放这些临时文件,不会导致相互干扰。
因此,我们需要让每个请求创建的 WORK_DIR 目录都不相同!这时候就可以使用 “唯一 ID” 来作为目录的名字~
UUID
UUID 是计算机中非常常用的一个概念,表示一个 “全世界都唯一的 id”。每次生成的一个 UUID,会根据一系列算法,来保证这个 UUID 是唯一的。
每个请求,都生成一个唯一的 UUID,进一步创建一个以 UUID 命名的临时目录。最后把生成的临时文件都放在这个临时目录中即可。
对 Task 类进行重构
把开头的一组常量修改成变量。
然后创建一个构造方法,在里面生成 UUID 即可。
完整的 Task 类
// 编译运行
public class Task {
// 通过一组常量来约定临时文件的名字
// 表示所有临时文件所在的目录
private String WORK_DIR = null;
// 约定代码的类名
private String CLASS = null;
// 约定要编译的代码文件名
private String CODE = null;
// 约定存放编译错误信息的文件名
private String COMPILE_ERROR = null;
// 约定存放运行时标准输出的文件名
private String STDOUT = null;
// 存放运行时标准错误的文件名
private String STDERR = null;
public Task() {
// 在 Java 中使用 UUID 这个类,就能够生成一个 UUID
WORK_DIR = "./tmp/" + UUID.randomUUID().toString() + "/";
CLASS = "Solution";
CODE = WORK_DIR + "Solution.java";
COMPILE_ERROR = WORK_DIR + "compileError.txt";
STDOUT = WORK_DIR + "stdout.txt";
STDERR = WORK_DIR + "stderr.txt";
}
// 此类的核心方法。
// 参数:要编译运行的 Java 源代码;
// 返回值:表示编译运行结果。
public Answer compileAndRun(Question question) {
Answer answer = new Answer();
// 0. 准备好用来存放临时文件的目录
File workDir = new File(WORK_DIR);
// 判断是否存在该目录
if (!workDir.exists()) {
// 不存在则创建多级目录.
workDir.mkdirs();
}
// 1. 把 question 中的 code 写入到一个 Solution.java 文件中
FileUtil.writeFile(question.getCode(), CODE);
// 2. 创建子进程,调用 javac 进行编译。编译的时候,需要有一个 .java 文件
// 如果编译出错,javac 就会把错误信息写入到 stderr 里,使用专门的文件来保存:compileError.txt
String compileCmd = String.format("javac -encoding utf8 %s -d %s", CODE, WORK_DIR);
System.out.println("编译时:" + compileCmd);
CommandUtil.run(compileCmd, null, COMPILE_ERROR);
// 如果编译出错,错误信息就被记录到 COMPILE_ERROR 这个文件中。如果没有编译出错,该文件为空。
String compileError = FileUtil.readFile(COMPILE_ERROR);
if (!compileError.equals("")) {
System.out.println("编译出错!");
answer.setError(1);
answer.setReason(compileError);
return answer;
}
// 3. 创建子进程,调用 java 命令执行
// 运行程序的时候,也会把 java 子进程的标准输出和标准错误获取到. stdout.txt, stderr.txt
String runCmd = String.format("java -classpath %s %s", WORK_DIR, CLASS);
System.out.println("运行时:" + runCmd);
CommandUtil.run(runCmd, STDOUT, STDERR);
String runError = FileUtil.readFile(STDERR);
if (!runError.equals("")) {
System.out.println("运行时错误!");
answer.setError(2);
answer.setReason(runError);
return answer;
}
// 4. 父进程获取到刚才的编译执行结果,并打包成 compile.Answer 对象
// 正常编译运行的结果,就通过刚才约定的文件来进行获取
answer.setError(0);
answer.setStdout(FileUtil.readFile(STDOUT));
return answer;
}
public static void main(String[] args) {
Task task = new Task();
// 待编译代码
Question question = new Question();
question.setCode("public class Solution {\n" +
" public static void main(String[] args) {\n" +
" System.out.println(\"hello world\");\n" +
" }\n" +
"}\n");
// 编译运行后的结果
Answer answer = task.compileAndRun(question);
System.out.println(answer);
}
}
单独编译运行 Task 类,我们可以从项目目录的 tmp 文件中,发现已经生成了 UUID 命名的文件。
启动 Tomcat,发现没有生成目录
是因为相对路径的原因。
IDEA 中直接运行 Task 类,这时候的工作目录就是当前 Java 项目所在的目录。
IDEA 通过 SmartTomcat 来运行 Servlet 程序,此时的工作目录就是由 SmartTomcat 控制的。不想由 SmartTomcat 控制,就可以写绝对路径。
所以,当我们使用相对路径指定文件的时候,发现文件找不到,主要是工作目录是啥我们不知道。
我们为代码添加一端监控,查看 SmartTomcat 的工作目录。
// 查看 SmartTomcat 的工作目录
System.out.println("用户工作目录:" + System.getProperty("user.dir"));
重新运行 Tomcat,通过 Postman 发送请求,控制台就会输出工作目录,最后能够在 tmp 文件中找到生成的 UUID 目录。
三、校验代码的安全性
当前代码还存在一个严重的安全性问题。
在线 OJ 系统需要执行一段用户提交的代码,用户提交的代码,可能是存在安全隐患的。
大家可以试试,这段代码在 leetcode 上执行看看什么结果。
有诸多问题需要防范,目前能注意的到有这些:
- Runtime 能够执行一个程序指令,这个比较危险。
- 代码中可能存在一些 “读写操作”,黑客可能直接把一个病毒程序写到你的机器上。
- 代码中如果存在一些 “网络” 操作,也是比较危险的。
解决方法
一个简单粗暴的方法,就是使用一个黑名单,把有危险的代码特性,都放在黑名单中。
在获取到用户提交代码的时候,就查询一个当前是否命中黑名单,如果命中黑名单就直接报错,不去编译执行。
// 黑名单
private boolean checkCodeSafe(String code) {
List<String> blackList = new ArrayList<>();
// 恶意代码
blackList.add("Runtime");
blackList.add("exec");
// 禁止读写文件
blackList.add("java.io");
// 禁止访问网络
blackList.add("java.net");
for (String target : blackList) {
int pos = code.indexOf(target);
if (pos > 0) {
return false;
}
}
return true;
}
四、阶段性总结
- 基于多进程编程的方式,创建了一个 CommandUtil 类,来封装创建进程完成任务的工作。
- 创建了 Task 类,把整个编译运行过程进行了封装。
- 创建了数据库和数据表,设计了题目的存储方式。
- 封装了数据库操作(Problem 和 ProblemDAO)。
- 设计了前后端交互的 API。
- 实现了这些前后端交互的 API。
到这里,我们 online-OJ 项目的服务器后台实现的差不多了。
我们继续实现前端部分,实现 online-OJ 项目的界面。