Relax bounds checks. It was rejecting correct code.

Bounds checks were too strict, causing the bytecode to abort when it shouldn't.
This happened when trying to access the last byte of an array: the verifier
was too conservative and considered the access to be out of bounds when in fact it wasn't.

This is an update of the runtime verifier from the bytecode compiler.
0.96
Török Edvin 16 years ago
parent 789b5255d2
commit daad92ace3
  1. 4
      ChangeLog
  2. 2
      libclamav/c++/ClamBCModule.h
  3. 140
      libclamav/c++/ClamBCRTChecks.cpp
  4. 29
      libclamav/c++/bytecode2llvm.cpp

@ -1,3 +1,7 @@
Thu May 13 12:41:24 EEST 2010 (edwin)
------------------------------------
* libclamav/c++: Relax bounds checks. Was rejecting correct code.
Wed May 12 19:10:39 CEST 2010 (acab)
------------------------------------
* docs/man: add clamav.milter.conf.5

@ -5,6 +5,6 @@ namespace llvm {
class Pass;
}
namespace ClamBCModule {
void stop(const char *msg, llvm::Function* F, llvm::Instruction* I);
void stop(const char *msg, llvm::Function* F, llvm::Instruction* I=0);
}
llvm::Pass *createClamBCRTChecks();

@ -21,6 +21,7 @@
*/
#define DEBUG_TYPE "clambc-rtcheck"
#include "ClamBCModule.h"
#include "ClamBCDiagnostics.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/SCCIterator.h"
@ -125,7 +126,7 @@ namespace {
Value *V = CI->getCalledValue()->stripPointerCasts();
Function *F = dyn_cast<Function>(V);
if (!F) {
printLocation(errs(), CI);
printLocation(CI, true);
errs() << "Could not determine call target\n";
valid = 0;
continue;
@ -171,7 +172,7 @@ namespace {
if (isa<PointerType>(FTy->getParamType(i))) {
Value *Ptr = CI->getOperand(i+1);
if (i+1 >= FTy->getNumParams()) {
printLocation(errs(), CI);
printLocation(CI, false);
errs() << "Call to external function with pointer parameter last cannot be analyzed\n";
errs() << *CI << "\n";
valid = 0;
@ -179,7 +180,7 @@ namespace {
}
Value *Size = CI->getOperand(i+2);
if (!Size->getType()->isIntegerTy()) {
printLocation(errs(), CI);
printLocation(CI, false);
errs() << "Pointer argument must be followed by integer argument representing its size\n";
errs() << *CI << "\n";
valid = 0;
@ -195,7 +196,7 @@ namespace {
if (!valid) {
DEBUG(F.dump());
ClamBCModule::stop("Verification found errors!", &F, 0);
ClamBCModule::stop("Verification found errors!", &F);
// replace function with call to abort
std::vector<const Type*>args;
FunctionType* abrtTy = FunctionType::get(
@ -407,7 +408,34 @@ namespace {
return BoundsMap[Base] = V;
}
bool insertCheck(const SCEV *Idx, const SCEV *Limit, Instruction *I)
// Looks up a debug-location metadata node for instruction I.
//
// I           - instruction whose location is wanted.
// Approximate - out-parameter; false when the node returned is attached to I
//               itself (or when MDDbgKind is 0), true once the search has
//               moved on to other instructions.
// MDDbgKind   - metadata kind id used for every getMetadata() lookup.
//
// Returns the metadata node found, or 0 when none is found.
//
// Search order: I itself, then the instructions preceding I in its basic
// block (nearest first), then each block reached by repeatedly following
// getUniquePredecessor(), scanned from its last instruction backwards.
//
// NOTE(review): the predecessor walk keeps no visited set; it ends only when
// getUniquePredecessor() returns null. Presumably a cycle of single-predecessor
// blocks cannot occur in reachable code -- confirm.
MDNode *getLocation(Instruction *I, bool &Approximate, unsigned MDDbgKind)
{
    Approximate = false;
    MDNode *Exact = I->getMetadata(MDDbgKind);
    if (Exact)
        return Exact;
    if (MDDbgKind == 0)
        return 0;

    // From here on any hit comes from a different instruction than I.
    Approximate = true;
    BasicBlock *Block = I->getParent();
    BasicBlock::iterator Cursor = I;
    for (;;) {
        // Scan backwards from Cursor (exclusive) to the start of Block.
        while (Cursor != Block->begin()) {
            --Cursor;
            MDNode *Nearby = Cursor->getMetadata(MDDbgKind);
            if (Nearby)
                return Nearby;
        }
        Block = Block->getUniquePredecessor();
        if (!Block)
            break;
        Cursor = Block->end();
    }
    return 0;
}
bool insertCheck(const SCEV *Idx, const SCEV *Limit, Instruction *I,
bool strict)
{
if (isa<SCEVCouldNotCompute>(Idx) && isa<SCEVCouldNotCompute>(Limit)) {
errs() << "Could not compute the index and the limit!: \n" << *I << "\n";
@ -425,22 +453,28 @@ namespace {
BasicBlock::iterator It = I;
BasicBlock *newBB = SplitBlock(BB, &*It, this);
PHINode *PN;
unsigned MDDbgKind = I->getContext().getMDKindID("dbg");
//verifyFunction(*BB->getParent());
if (!AbrtBB) {
std::vector<const Type*>args;
FunctionType* abrtTy = FunctionType::get(
Type::getVoidTy(BB->getContext()),args,false);
args.push_back(Type::getInt32Ty(BB->getContext()));
// FunctionType* rterrTy = FunctionType::get(
// Type::getInt32Ty(BB->getContext()),args,false);
FunctionType* rterrTy = FunctionType::get(
Type::getInt32Ty(BB->getContext()),args,false);
Constant *func_abort =
BB->getParent()->getParent()->getOrInsertFunction("abort", abrtTy);
// Constant *func_rterr =
// BB->getParent()->getParent()->getOrInsertFunction("bytecode_rt_error", rterrTy);
Constant *func_rterr =
BB->getParent()->getParent()->getOrInsertFunction("bytecode_rt_error", rterrTy);
AbrtBB = BasicBlock::Create(BB->getContext(), "", BB->getParent());
PN = PHINode::Create(Type::getInt32Ty(BB->getContext()),"",
AbrtBB);
// CallInst *RtErrCall = CallInst::Create(func_rterr, PN, "", AbrtBB);
if (MDDbgKind) {
CallInst *RtErrCall = CallInst::Create(func_rterr, PN, "", AbrtBB);
RtErrCall->setCallingConv(CallingConv::C);
RtErrCall->setTailCall(true);
RtErrCall->setDoesNotThrow(true);
}
CallInst* AbrtC = CallInst::Create(func_abort, "", AbrtBB);
AbrtC->setCallingConv(CallingConv::C);
AbrtC->setTailCall(true);
@ -452,16 +486,24 @@ namespace {
} else {
PN = cast<PHINode>(AbrtBB->begin());
}
unsigned MDDbgKind = I->getContext().getMDKindID("dbg");
unsigned locationid = 0;
if (MDNode *Dbg = I->getMetadata(MDDbgKind)) {
bool Approximate;
if (MDNode *Dbg = getLocation(I, Approximate, MDDbgKind)) {
DILocation Loc(Dbg);
locationid = Loc.getLineNumber() << 8;
unsigned col = Loc.getColumnNumber();
if (col > 255)
if (col > 254)
col = 254;
if (Approximate)
col = 255;
locationid |= col;
// Loc.getFilename();
} else {
static int wcounters = 100000;
locationid = (wcounters++)<<8;
/*errs() << "fake location: " << (locationid>>8) << "\n";
I->dump();
I->getParent()->dump();*/
}
PN->addIncoming(ConstantInt::get(Type::getInt32Ty(BB->getContext()),
locationid), BB);
@ -475,7 +517,9 @@ namespace {
//verifyFunction(*BB->getParent());
Value *LimitV = expander.expandCodeFor(Limit, Limit->getType(), TI);
//verifyFunction(*BB->getParent());
Value *Cond = new ICmpInst(TI, ICmpInst::ICMP_ULT, IdxV, LimitV);
Value *Cond = new ICmpInst(TI, strict ?
ICmpInst::ICMP_ULT :
ICmpInst::ICMP_ULE, IdxV, LimitV);
//verifyFunction(*BB->getParent());
BranchInst::Create(newBB, AbrtBB, Cond, TI);
TI->eraseFromParent();
@ -525,31 +569,6 @@ namespace {
}
return false;
}
// Writes a description of V to Out.
// When getLocationInfo() succeeds, prints "'<DisplayName>' (<File>:<Line>)"
// with no trailing newline; otherwise prints V itself followed by a newline.
static void printValue(llvm::raw_ostream &Out, llvm::Value *V) {
    std::string DisplayName, Type, File, Dir;
    unsigned Line;
    if (getLocationInfo(V, DisplayName, Type, Line, File, Dir)) {
        Out << "'" << DisplayName << "' (" << File << ":" << Line << ")";
        return;
    }
    // No location info available: fall back to printing the value.
    Out << *V << "\n";
}
// Writes a "<file>:<line>[:<col>]: " prefix for I to Out, taken from its
// "dbg" metadata. The column is omitted when it is 0. When I carries no
// "dbg" metadata, prints the instruction itself followed by ":\n".
static void printLocation(llvm::raw_ostream &Out, llvm::Instruction *I) {
    MDNode *N = I->getMetadata("dbg");
    if (!N) {
        Out << *I << ":\n";
        return;
    }
    DILocation Loc(N);
    Out << Loc.getFilename() << ":" << Loc.getLineNumber();
    unsigned Col = Loc.getColumnNumber();
    if (Col) {
        Out << ":" << Col;
    }
    Out << ": ";
}
bool validateAccess(Value *Pointer, Value *Length, Instruction *I)
{
@ -560,13 +579,13 @@ namespace {
// get bounds
Value *Bounds = getPointerBounds(SBase);
if (!Bounds) {
printLocation(errs(), I);
printLocation(I, true);
errs() << "no bounds for base ";
printValue(errs(), SBase);
printValue(SBase);
errs() << " while checking access to ";
printValue(errs(), Pointer);
printValue(Pointer);
errs() << " of length ";
printValue(errs(), Length);
printValue(Length);
errs() << "\n";
return false;
@ -574,17 +593,17 @@ namespace {
if (CallInst *CI = dyn_cast<CallInst>(Base->stripPointerCasts())) {
if (I->getParent() == CI->getParent()) {
printLocation(errs(), I);
printLocation(I, true);
errs() << "no null pointer check of pointer ";
printValue(errs(), Base);
printValue(Base, false, true);
errs() << " obtained by function call";
errs() << " before use in same block\n";
return false;
}
if (!checkCondition(CI, I)) {
printLocation(errs(), I);
printLocation(I, true);
errs() << "no null pointer check of pointer ";
printValue(errs(), Base);
printValue(Base, false, true);
errs() << " obtained by function call";
errs() << " before use\n";
return false;
@ -603,8 +622,16 @@ namespace {
DEBUG(dbgs() << "Checking access to " << *Pointer << " of length " <<
*Length << "\n");
if (OffsetP == Limit)
return true;
if (OffsetP == Limit) {
printLocation(I, true);
errs() << "OffsetP == Limit: " << *OffsetP << "\n";
errs() << " while checking access to ";
printValue(Pointer);
errs() << " of length ";
printValue(Length);
errs() << "\n";
return false;
}
if (SLen == Limit) {
if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(OffsetP)) {
@ -614,22 +641,29 @@ namespace {
errs() << "SLen == Limit: " << *SLen << "\n";
errs() << " while checking access to " << *Pointer << " of length "
<< *Length << " at " << *I << "\n";
return false;//TODO: insert abort
return false;
}
bool valid = true;
SLen = SE->getAddExpr(OffsetP, SLen);
// check that offset + slen <= limit;
// umax(offset+slen, limit) == limit is a sufficient (but not necessary
// condition)
const SCEV *MaxL = SE->getUMaxExpr(SLen, Limit);
if (MaxL != Limit) {
DEBUG(dbgs() << "MaxL != Limit: " << *MaxL << ", " << *Limit << "\n");
return insertCheck(SLen, Limit, I);
valid &= insertCheck(SLen, Limit, I, false);
}
//TODO: nullpointer check
const SCEV *Max = SE->getUMaxExpr(OffsetP, Limit);
if (Max == Limit)
return true;
return valid;
DEBUG(dbgs() << "Max != Limit: " << *Max << ", " << *Limit << "\n");
return insertCheck(OffsetP, Limit, I);
// check that offset < limit
valid &= insertCheck(OffsetP, Limit, I, true);
return valid;
}
bool validateAccess(Value *Pointer, unsigned size, Instruction *I)

@ -25,6 +25,8 @@
#include <sys/time.h>
#endif
#include "ClamBCModule.h"
#include "ClamBCDiagnostics.h"
#include "llvm/Analysis/DebugInfo.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/PostOrderIterator.h"
@ -73,6 +75,7 @@
#include <csetjmp>
#include <new>
#include <cerrno>
#include <string>
#include "llvm/Config/config.h"
#if !ENABLE_THREADS
@ -2044,3 +2047,29 @@ void stop(const char *msg, llvm::Function* F, llvm::Instruction* I)
llvm::errs() << msg << "\n";
}
}
// Prints a description of V to errs().
// When getLocationInfo() succeeds, prints "'<DisplayName>' (<File>:<Line>)"
// with no trailing newline; otherwise prints V itself followed by a newline.
// The flags a and b are not used by this implementation.
//
// NOTE(review): only the fallback path emits a newline, so callers that keep
// appending text after this call get a line break in that case -- confirm
// whether that is intended.
void printValue(llvm::Value *V, bool a, bool b) {
    std::string DisplayName, Type, File, Dir;
    unsigned Line;
    if (getLocationInfo(V, DisplayName, Type, Line, File, Dir)) {
        errs() << "'" << DisplayName << "' (" << File << ":" << Line << ")";
        return;
    }
    // No location info available: fall back to printing the value.
    errs() << *V << "\n";
}
// Prints a "<file>:<line>[:<col>]: " prefix for I to errs(), taken from its
// "dbg" metadata. The column is omitted when it is 0. When I carries no
// "dbg" metadata, prints the instruction itself followed by ":\n".
// The flags a and b are not used by this implementation.
void printLocation(llvm::Instruction *I, bool a, bool b) {
    MDNode *N = I->getMetadata("dbg");
    if (!N) {
        errs() << *I << ":\n";
        return;
    }
    DILocation Loc(N);
    errs() << Loc.getFilename() << ":" << Loc.getLineNumber();
    unsigned Col = Loc.getColumnNumber();
    if (Col) {
        errs() << ":" << Col;
    }
    errs() << ": ";
}

Loading…
Cancel
Save