diff --git a/ChangeLog b/ChangeLog index ba8439bc7..59cd3fe68 100644 --- a/ChangeLog +++ b/ChangeLog @@ -1,3 +1,7 @@ +Thu May 13 12:41:24 EEST 2010 (edwin) +------------------------------------ + * libclamav/c++: Relax bounds checks. Was rejecting correct code. + Wed May 12 19:10:39 CEST 2010 (acab) ------------------------------------ * docs/man: add clamav.milter.conf.5 diff --git a/libclamav/c++/ClamBCModule.h b/libclamav/c++/ClamBCModule.h index cf8339b95..07cc3d6a2 100644 --- a/libclamav/c++/ClamBCModule.h +++ b/libclamav/c++/ClamBCModule.h @@ -5,6 +5,6 @@ namespace llvm { class Pass; } namespace ClamBCModule { - void stop(const char *msg, llvm::Function* F, llvm::Instruction* I); + void stop(const char *msg, llvm::Function* F, llvm::Instruction* I=0); } llvm::Pass *createClamBCRTChecks(); diff --git a/libclamav/c++/ClamBCRTChecks.cpp b/libclamav/c++/ClamBCRTChecks.cpp index 9afb1e396..697a937d9 100644 --- a/libclamav/c++/ClamBCRTChecks.cpp +++ b/libclamav/c++/ClamBCRTChecks.cpp @@ -21,6 +21,7 @@ */ #define DEBUG_TYPE "clambc-rtcheck" #include "ClamBCModule.h" +#include "ClamBCDiagnostics.h" #include "llvm/ADT/DenseSet.h" #include "llvm/ADT/PostOrderIterator.h" #include "llvm/ADT/SCCIterator.h" @@ -125,7 +126,7 @@ namespace { Value *V = CI->getCalledValue()->stripPointerCasts(); Function *F = dyn_cast(V); if (!F) { - printLocation(errs(), CI); + printLocation(CI, true); errs() << "Could not determine call target\n"; valid = 0; continue; @@ -171,7 +172,7 @@ namespace { if (isa(FTy->getParamType(i))) { Value *Ptr = CI->getOperand(i+1); if (i+1 >= FTy->getNumParams()) { - printLocation(errs(), CI); + printLocation(CI, false); errs() << "Call to external function with pointer parameter last cannot be analyzed\n"; errs() << *CI << "\n"; valid = 0; @@ -179,7 +180,7 @@ namespace { } Value *Size = CI->getOperand(i+2); if (!Size->getType()->isIntegerTy()) { - printLocation(errs(), CI); + printLocation(CI, false); errs() << "Pointer argument must be followed by integer argument representing its size\n"; errs() << *CI << "\n"; valid = 0; @@ -195,7 +196,7 @@ namespace { if (!valid) { DEBUG(F.dump()); - ClamBCModule::stop("Verification found errors!", &F, 0); + ClamBCModule::stop("Verification found errors!", &F); // replace function with call to abort std::vectorargs; FunctionType* abrtTy = FunctionType::get( @@ -407,7 +408,34 @@ namespace { return BoundsMap[Base] = V; } - bool insertCheck(const SCEV *Idx, const SCEV *Limit, Instruction *I) + MDNode *getLocation(Instruction *I, bool &Approximate, unsigned MDDbgKind) + { + Approximate = false; + if (MDNode *Dbg = I->getMetadata(MDDbgKind)) + return Dbg; + if (!MDDbgKind) + return 0; + Approximate = true; + BasicBlock::iterator It = I; + while (It != I->getParent()->begin()) { + --It; + if (MDNode *Dbg = It->getMetadata(MDDbgKind)) + return Dbg; + } + BasicBlock *BB = I->getParent(); + while ((BB = BB->getUniquePredecessor())) { + It = BB->end(); + while (It != BB->begin()) { + --It; + if (MDNode *Dbg = It->getMetadata(MDDbgKind)) + return Dbg; + } + } + return 0; + } + + bool insertCheck(const SCEV *Idx, const SCEV *Limit, Instruction *I, + bool strict) { if (isa(Idx) && isa(Limit)) { errs() << "Could not compute the index and the limit!: \n" << *I << "\n"; @@ -425,22 +453,28 @@ namespace { BasicBlock::iterator It = I; BasicBlock *newBB = SplitBlock(BB, &*It, this); PHINode *PN; + unsigned MDDbgKind = I->getContext().getMDKindID("dbg"); //verifyFunction(*BB->getParent()); if (!AbrtBB) { std::vectorargs; FunctionType* abrtTy = FunctionType::get( Type::getVoidTy(BB->getContext()),args,false); args.push_back(Type::getInt32Ty(BB->getContext())); -// FunctionType* rterrTy = FunctionType::get( -// Type::getInt32Ty(BB->getContext()),args,false); + FunctionType* rterrTy = FunctionType::get( + Type::getInt32Ty(BB->getContext()),args,false); Constant *func_abort = BB->getParent()->getParent()->getOrInsertFunction("abort", abrtTy); -// Constant *func_rterr = -// BB->getParent()->getParent()->getOrInsertFunction("bytecode_rt_error", rterrTy); + Constant *func_rterr = + BB->getParent()->getParent()->getOrInsertFunction("bytecode_rt_error", rterrTy); AbrtBB = BasicBlock::Create(BB->getContext(), "", BB->getParent()); PN = PHINode::Create(Type::getInt32Ty(BB->getContext()),"", AbrtBB); -// CallInst *RtErrCall = CallInst::Create(func_rterr, PN, "", AbrtBB); + if (MDDbgKind) { + CallInst *RtErrCall = CallInst::Create(func_rterr, PN, "", AbrtBB); + RtErrCall->setCallingConv(CallingConv::C); + RtErrCall->setTailCall(true); + RtErrCall->setDoesNotThrow(true); + } CallInst* AbrtC = CallInst::Create(func_abort, "", AbrtBB); AbrtC->setCallingConv(CallingConv::C); AbrtC->setTailCall(true); @@ -452,16 +486,24 @@ namespace { } else { PN = cast(AbrtBB->begin()); } - unsigned MDDbgKind = I->getContext().getMDKindID("dbg"); unsigned locationid = 0; - if (MDNode *Dbg = I->getMetadata(MDDbgKind)) { + bool Approximate; + if (MDNode *Dbg = getLocation(I, Approximate, MDDbgKind)) { DILocation Loc(Dbg); locationid = Loc.getLineNumber() << 8; unsigned col = Loc.getColumnNumber(); - if (col > 255) + if (col > 254) + col = 254; + if (Approximate) col = 255; locationid |= col; // Loc.getFilename(); + } else { + static int wcounters = 100000; + locationid = (wcounters++)<<8; + /*errs() << "fake location: " << (locationid>>8) << "\n"; + I->dump(); + I->getParent()->dump();*/ } PN->addIncoming(ConstantInt::get(Type::getInt32Ty(BB->getContext()), locationid), BB); @@ -475,7 +517,9 @@ namespace { //verifyFunction(*BB->getParent()); Value *LimitV = expander.expandCodeFor(Limit, Limit->getType(), TI); //verifyFunction(*BB->getParent()); - Value *Cond = new ICmpInst(TI, ICmpInst::ICMP_ULT, IdxV, LimitV); + Value *Cond = new ICmpInst(TI, strict ? + ICmpInst::ICMP_ULT : + ICmpInst::ICMP_ULE, IdxV, LimitV); //verifyFunction(*BB->getParent()); BranchInst::Create(newBB, AbrtBB, Cond, TI); TI->eraseFromParent(); @@ -525,31 +569,6 @@ namespace { } return false; } - static void printValue(llvm::raw_ostream &Out, llvm::Value *V) { - std::string DisplayName; - std::string Type; - unsigned Line; - std::string File; - std::string Dir; - if (!getLocationInfo(V, DisplayName, Type, Line, File, Dir)) { - Out << *V << "\n"; - return; - } - Out << "'" << DisplayName << "' (" << File << ":" << Line << ")"; - } - - static void printLocation(llvm::raw_ostream &Out, llvm::Instruction *I) { - if (MDNode *N = I->getMetadata("dbg")) { - DILocation Loc(N); - Out << Loc.getFilename() << ":" << Loc.getLineNumber(); - if (unsigned Col = Loc.getColumnNumber()) { - Out << ":" << Col; - } - Out << ": "; - return; - } - Out << *I << ":\n"; - } bool validateAccess(Value *Pointer, Value *Length, Instruction *I) { @@ -560,13 +579,13 @@ namespace { // get bounds Value *Bounds = getPointerBounds(SBase); if (!Bounds) { - printLocation(errs(), I); + printLocation(I, true); errs() << "no bounds for base "; - printValue(errs(), SBase); + printValue(SBase); errs() << " while checking access to "; - printValue(errs(), Pointer); + printValue(Pointer); errs() << " of length "; - printValue(errs(), Length); + printValue(Length); errs() << "\n"; return false; @@ -574,17 +593,17 @@ namespace { if (CallInst *CI = dyn_cast(Base->stripPointerCasts())) { if (I->getParent() == CI->getParent()) { - printLocation(errs(), I); + printLocation(I, true); errs() << "no null pointer check of pointer "; - printValue(errs(), Base); + printValue(Base, false, true); errs() << " obtained by function call"; errs() << " before use in same block\n"; return false; } if (!checkCondition(CI, I)) { - printLocation(errs(), I); + printLocation(I, true); errs() << "no null pointer check of pointer "; - printValue(errs(), Base); + printValue(Base, false, true); errs() << " obtained by function call"; errs() << " before use\n"; return false; @@ -603,8 +622,16 @@ namespace { DEBUG(dbgs() << "Checking access to " << *Pointer << " of length " << *Length << "\n"); - if (OffsetP == Limit) - return true; + if (OffsetP == Limit) { + printLocation(I, true); + errs() << "OffsetP == Limit: " << *OffsetP << "\n"; + errs() << " while checking access to "; + printValue(Pointer); + errs() << " of length "; + printValue(Length); + errs() << "\n"; + return false; + } if (SLen == Limit) { if (const SCEVConstant *SC = dyn_cast(OffsetP)) { @@ -614,22 +641,29 @@ namespace { errs() << "SLen == Limit: " << *SLen << "\n"; errs() << " while checking access to " << *Pointer << " of length " << *Length << " at " << *I << "\n"; - return false;//TODO: insert abort + return false; } + bool valid = true; + SLen = SE->getAddExpr(OffsetP, SLen); + // check that offset + slen <= limit; + // umax(offset+slen, limit) == limit is a sufficient (but not necessary + // condition) const SCEV *MaxL = SE->getUMaxExpr(SLen, Limit); if (MaxL != Limit) { DEBUG(dbgs() << "MaxL != Limit: " << *MaxL << ", " << *Limit << "\n"); - return insertCheck(SLen, Limit, I); + valid &= insertCheck(SLen, Limit, I, false); } //TODO: nullpointer check const SCEV *Max = SE->getUMaxExpr(OffsetP, Limit); if (Max == Limit) - return true; + return valid; DEBUG(dbgs() << "Max != Limit: " << *Max << ", " << *Limit << "\n"); - return insertCheck(OffsetP, Limit, I); + // check that offset < limit + valid &= insertCheck(OffsetP, Limit, I, true); + return valid; } bool validateAccess(Value *Pointer, unsigned size, Instruction *I) diff --git a/libclamav/c++/bytecode2llvm.cpp b/libclamav/c++/bytecode2llvm.cpp index e6b73a801..c5d3048ad 100644 --- a/libclamav/c++/bytecode2llvm.cpp +++ b/libclamav/c++/bytecode2llvm.cpp @@ -25,6 +25,8 @@ #include #endif #include "ClamBCModule.h" +#include "ClamBCDiagnostics.h" +#include "llvm/Analysis/DebugInfo.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/BitVector.h" #include "llvm/ADT/PostOrderIterator.h" @@ -73,6 +75,7 @@ #include #include #include +#include #include "llvm/Config/config.h" #if !ENABLE_THREADS @@ -2044,3 +2047,29 @@ void stop(const char *msg, llvm::Function* F, llvm::Instruction* I) llvm::errs() << msg << "\n"; } } + +void printValue(llvm::Value *V, bool a, bool b) { + std::string DisplayName; + std::string Type; + unsigned Line; + std::string File; + std::string Dir; + if (!getLocationInfo(V, DisplayName, Type, Line, File, Dir)) { + errs() << *V << "\n"; + return; + } + errs() << "'" << DisplayName << "' (" << File << ":" << Line << ")"; +} + +void printLocation(llvm::Instruction *I, bool a, bool b) { + if (MDNode *N = I->getMetadata("dbg")) { + DILocation Loc(N); + errs() << Loc.getFilename() << ":" << Loc.getLineNumber(); + if (unsigned Col = Loc.getColumnNumber()) { + errs() << ":" << Col; + } + errs() << ": "; + return; + } + errs() << *I << ":\n"; +}