#include "pipeline.h"

#include "sir/util/cloning.h"
#include "sir/util/irtools.h"
#include "sir/util/matching.h"

#include <algorithm>
#include <iterator>
#include <string>
#include <vector>

namespace seq {

using namespace ir;

namespace {
const std::string prefetchModule = "std.bio.prefetch";
const std::string builtinModule = "std.bio.builtin";
const std::string alignModule = "std.bio.align";
const std::string seqModule = "std.bio.seq";

/// Returns true if at least one stage of pipeline `p` is marked parallel.
bool isParallel(PipelineFlow *p) {
  return std::any_of(p->begin(), p->end(), [](const PipelineFlow::Stage &stage) {
    return stage.isParallel();
  });
}

/// Creates a new function named "__stage_wrapper" that forwards to `callee`.
///
/// The wrapper's first parameter ("0") has type `inputType`. It is followed by
/// one parameter ("1", "2", ...) for each non-null argument of `stage`, in
/// order. The body calls `callee` with `stage`'s argument layout: each null
/// argument slot (the pipeline placeholder) receives the wrapper's first
/// parameter, and each non-null slot receives the next extra parameter. The
/// call's result is returned.
///
/// @param stage     pipeline stage whose argument layout is reproduced
/// @param callee    function the wrapper calls
/// @param inputType type of the value flowing into the stage
/// @return the realized wrapper function
BodiedFunc *makeStageWrapperFunc(PipelineFlow::Stage *stage, Func *callee,
                                 types::Type *inputType) {
  auto *M = callee->getModule();
  std::vector<types::Type *> argTypes = {inputType};
  std::vector<std::string> argNames = {"0"};
  int i = 1;
  for (auto *arg : *stage) {
    if (arg) {
      argTypes.push_back(arg->getType());
      argNames.push_back(std::to_string(i++));
    }
  }
  auto *funcType = M->getFuncType(util::getReturnType(callee), argTypes);
  auto *wrapperFunc = M->Nr<BodiedFunc>("__stage_wrapper");
  wrapperFunc->realize(funcType, argNames);

  // Reorder arguments: parameter 0 fills every placeholder (null) slot; the
  // remaining parameters fill the explicit slots in their original order.
  std::vector<Value *> args;
  auto it = wrapperFunc->arg_begin();
  ++it; // skip parameter 0 (the pipeline input)
  for (auto *arg : *stage) {
    if (arg) {
      args.push_back(M->Nr<VarValue>(*it++));
    } else {
      args.push_back(M->Nr<VarValue>(wrapperFunc->arg_front()));
    }
  }

  wrapperFunc->setBody(util::series(M->Nr<ReturnInstr>(util::call(callee, args))));
  return wrapperFunc;
}

// check for a regular func, or a flow-instr containing a func,
// which is caused by a partial function.
BodiedFunc *getStageFunc(PipelineFlow::Stage &stage) { auto *callee = stage.getCallee(); if (auto *f = cast(util::getFunc(callee))) { return f; } else if (auto *s = cast(callee)) { if (auto *f = cast(util::getFunc(s->getValue()))) { return f; } } return nullptr; } Value *replaceStageFunc(PipelineFlow::Stage &stage, Func *schedFunc, util::CloneVisitor &cv) { auto *callee = stage.getCallee(); auto *M = callee->getModule(); if (auto *f = cast(util::getFunc(callee))) { return M->Nr(schedFunc); } else if (auto *s = cast(callee)) { if (auto *f = cast(util::getFunc(s->getValue()))) { auto *clone = cast(cv.clone(s)); clone->setValue(M->Nr(schedFunc)); return clone; } } seqassert(0, "invalid stage func replacement"); return nullptr; } } // namespace const std::string PipelineSubstitutionOptimization::KEY = "seq-pipeline-subst-opt"; const std::string PipelinePrefetchOptimization::KEY = "seq-pipeline-prefetch-opt"; const std::string PipelineInterAlignOptimization::KEY = "seq-pipeline-inter-align-opt"; /* * Substitution optimizations */ void PipelineSubstitutionOptimization::handle(PipelineFlow *p) { auto *M = p->getModule(); PipelineFlow::Stage *prev = nullptr; auto it = p->begin(); while (it != p->end()) { if (prev) { { auto *f1 = util::getStdlibFunc(prev->getCallee(), "kmers", "bio"); auto *f2 = util::getStdlibFunc(it->getCallee(), "revcomp", "bio"); if (f1 && f2) { auto *funcType = cast(f1->getType()); auto *genType = cast(funcType->getReturnType()); auto *seqType = funcType->front(); auto *kmerType = genType->getBase(); auto *kmersRevcompFunc = M->getOrRealizeFunc( "_kmers_revcomp", {seqType, M->getIntType()}, {kmerType}, builtinModule); seqassert(kmersRevcompFunc && util::getReturnType(kmersRevcompFunc)->is(genType), "invalid reverse complement function"); cast(prev->getCallee())->setVar(kmersRevcompFunc); if (it->isParallel()) prev->setParallel(); it = p->erase(it); continue; } } { auto *f1 = util::getStdlibFunc(prev->getCallee(), "kmers_with_pos", "bio"); auto *f2 = 
util::getStdlibFunc(it->getCallee(), "revcomp_with_pos", "bio"); if (f1 && f2) { auto *funcType = cast(f1->getType()); auto *genType = cast(funcType->getReturnType()); auto *seqType = funcType->front(); auto *kmerType = cast(genType->getBase())->back().getType(); auto *kmersRevcompWithPosFunc = M->getOrRealizeFunc("_kmers_revcomp_with_pos", {seqType, M->getIntType()}, {kmerType}, builtinModule); seqassert(kmersRevcompWithPosFunc && util::getReturnType(kmersRevcompWithPosFunc)->is(genType), "invalid pos reverse complement function"); cast(prev->getCallee())->setVar(kmersRevcompWithPosFunc); if (it->isParallel()) prev->setParallel(); it = p->erase(it); continue; } } { auto *f1 = util::getStdlibFunc(prev->getCallee(), "kmers", "bio"); auto *f2 = util::getStdlibFunc(it->getCallee(), "canonical", "bio"); if (f1 && f2 && util::isConst(prev->back(), 1)) { auto *funcType = cast(f1->getType()); auto *genType = cast(funcType->getReturnType()); auto *seqType = funcType->front(); auto *kmerType = genType->getBase(); auto *kmersCanonicalFunc = M->getOrRealizeFunc("_kmers_canonical", {seqType}, {kmerType}, builtinModule); seqassert(kmersCanonicalFunc && util::getReturnType(kmersCanonicalFunc)->is(genType), "invalid canonical kmers function"); cast(prev->getCallee())->setVar(kmersCanonicalFunc); prev->erase(prev->end() - 1); // remove step argument if (it->isParallel()) prev->setParallel(); it = p->erase(it); continue; } } { auto *f1 = util::getStdlibFunc(prev->getCallee(), "kmers_with_pos", "bio"); auto *f2 = util::getStdlibFunc(it->getCallee(), "canonical_with_pos", "bio"); if (f1 && f2 && util::isConst(prev->back(), 1)) { auto *funcType = cast(f1->getType()); auto *genType = cast(funcType->getReturnType()); auto *seqType = funcType->front(); auto *kmerType = cast(genType->getBase())->back().getType(); auto *kmersCanonicalWithPosFunc = M->getOrRealizeFunc( "_kmers_canonical_with_pos", {seqType}, {kmerType}, builtinModule); seqassert(kmersCanonicalWithPosFunc && 
util::getReturnType(kmersCanonicalWithPosFunc)->is(genType), "invalid pos canonical kmers function"); cast(prev->getCallee())->setVar(kmersCanonicalWithPosFunc); prev->erase(prev->end() - 1); // remove step argument if (it->isParallel()) prev->setParallel(); it = p->erase(it); continue; } } } prev = &*it; ++it; } } /* * Prefetch optimization */ struct PrefetchFunctionTransformer : public util::Operator { void handle(ReturnInstr *x) override { auto *M = x->getModule(); x->replaceAll(M->Nr(x->getValue(), /*final=*/true)); } void handle(CallInstr *x) override { auto *func = cast(util::getFunc(x->getCallee())); if (!func || func->getUnmangledName() != Module::GETITEM_MAGIC_NAME || x->numArgs() != 2) return; auto *M = x->getModule(); Value *self = x->front(); Value *key = x->back(); types::Type *selfType = self->getType(); types::Type *keyType = key->getType(); Func *prefetchFunc = M->getOrRealizeMethod(selfType, "__prefetch__", {selfType, keyType}); if (!prefetchFunc) return; Value *prefetch = util::call(prefetchFunc, {self, key}); auto *yield = M->Nr(); auto *replacement = util::series(prefetch, yield); util::CloneVisitor cv(M); auto *clone = cv.clone(x); see(clone); // avoid infinite loop on clone x->replaceAll(M->Nr(replacement, clone)); } }; void PipelinePrefetchOptimization::handle(PipelineFlow *p) { if (isParallel(p)) return; auto *M = p->getModule(); PrefetchFunctionTransformer pft; PipelineFlow::Stage *prev = nullptr; util::CloneVisitor cv(M); for (auto it = p->begin(); it != p->end(); ++it) { if (auto *func = getStageFunc(*it)) { if (!it->isGenerator() && util::hasAttribute(func, "std.bio.builtin.prefetch")) { // transform prefetch'ing function auto *clone = cast(cv.forceClone(func)); util::setReturnType(clone, M->getGeneratorType(util::getReturnType(clone))); clone->setGenerator(); clone->getBody()->accept(pft); // make sure the arguments are in the correct order auto *inputType = prev->getOutputElementType(); clone = makeStageWrapperFunc(&*it, clone, 
inputType); auto *coroType = cast(clone->getType()); // vars auto *statesType = M->getArrayType(coroType->getReturnType()); seqassert((SCHED_WIDTH_PREFETCH & (SCHED_WIDTH_PREFETCH - 1)) == 0, "not a power of 2"); // power of 2 auto *width = M->getInt(SCHED_WIDTH_PREFETCH); auto *init = M->Nr(); auto *parent = cast(getParentFunc()); seqassert(parent, "not in a function"); auto *filled = util::makeVar(M->getInt(0), init, parent); auto *next = util::makeVar(M->getInt(0), init, parent); auto *states = util::makeVar( M->Nr(statesType, SCHED_WIDTH_PREFETCH), init, parent); insertBefore(init); // scheduler auto *intType = M->getIntType(); auto *intPtrType = M->getPointerType(intType); std::vector stageArgTypes; std::vector stageArgs; for (auto *arg : *it) { if (arg) { stageArgs.push_back(arg); stageArgTypes.push_back(arg->getType()); } } auto *extraArgs = util::makeTuple(stageArgs, M); std::vector argTypes = { inputType, coroType, statesType, intPtrType, intPtrType, intType, extraArgs->getType()}; Func *schedFunc = M->getOrRealizeFunc("_dynamic_coroutine_scheduler", argTypes, {}, prefetchModule); seqassert(schedFunc, "could not realize scheduler function"); PipelineFlow::Stage stage(replaceStageFunc(*it, schedFunc, cv), {nullptr, M->Nr(clone), states, M->Nr(next->getVar()), M->Nr(filled->getVar()), width, extraArgs}, /*generator=*/true, /*parallel=*/false); // drain Func *drainFunc = M->getOrRealizeFunc("_dynamic_coroutine_scheduler_drain", {statesType, intType}, {}, prefetchModule); std::vector args = {states, filled}; std::vector drainStages = { {util::call(drainFunc, args), {}, /*generator=*/true, /*parallel=*/false}}; *it = stage; if (std::distance(it, p->end()) == 1 && !util::getReturnType(func)->is(M->getVoidType())) { Func *dummyFunc = M->getOrRealizeFunc("_dummy_prefetch_terminal_stage", {stage.getOutputElementType()}, {}, prefetchModule); seqassert(dummyFunc, "could not realize dummy prefetch"); p->push_back({M->Nr(dummyFunc), {nullptr}, /*generator=*/false, 
/*parallel=*/false}); } for (++it; it != p->end(); ++it) { drainStages.push_back(cv.clone(*it)); } auto *drain = util::series(M->Nr(next->getVar(), M->getInt(0)), M->Nr(drainStages)); insertAfter(drain); LOG_REALIZE("[prefetch] {}", *p); break; // at most one prefetch transformation per pipeline } } prev = &*it; } } /* * Inter-sequence alignment optimization */ struct InterAlignTypes { types::Type *seq; // plain sequence type ('seq') types::Type *cigar; // CIGAR string type ('CIGAR') types::Type *align; // alignment result type ('Alignment') types::Type *params; // alignment parameters type ('InterAlignParams') types::Type *pair; // sequence pair type ('SeqPair') types::Type *yield; // inter-align yield type ('InterAlignYield') operator bool() const { return seq && cigar && align && params && pair && yield; } }; InterAlignTypes gatherInterAlignTypes(Module *M) { return {M->getOrRealizeType("seq", {}, seqModule), M->getOrRealizeType("CIGAR", {}, alignModule), M->getOrRealizeType("Alignment", {}, alignModule), M->getOrRealizeType("InterAlignParams", {}, alignModule), M->getOrRealizeType("SeqPair", {}, alignModule), M->getOrRealizeType("InterAlignYield", {}, alignModule)}; } bool isConstOrGlobal(const Value *x) { if (!x) { return false; } else if (auto *v = cast(x)) { return v->getVar()->isGlobal(); } else { return util::isConst(x) || util::isConst(x); } } bool isGlobalVar(Value *x) { if (auto *v = cast(x)) { return v->getVar()->isGlobal(); } return false; } template bool verifyAlignParams(T begin, T end) { enum ParamKind { SI, // supported int SB, // supported bool UI, // unsupported int UB, // unsupported bool }; /* a: int = 2, b: int = 4, ambig: int = 0, gapo: int = 4, gape: int = 2, gapo2: int = -1, gape2: int = -1, bandwidth: int = -1, zdrop: int = -1, end_bonus: int = 0, score_only: bool = False, right: bool = False, generic_sc: bool = False, approx_max: bool = False, approx_drop: bool = False, ext_only: bool = False, rev_cigar: bool = False, splice: bool = 
False, splice_fwd: bool = False, splice_rev: bool = False, splice_flank: bool = False */ ParamKind kinds[] = { SI, SI, SI, SI, SI, UI, UI, SI, SI, SI, SB, UB, UB, UB, UB, SB, SB, UB, UB, UB, UB, }; int i = 0; for (auto it = begin; it != end; ++it) { Value *v = *it; switch (kinds[i]) { case SI: if (!(isGlobalVar(v) || util::isConst(v))) return false; break; case SB: if (!(isGlobalVar(v) || util::isConst(v))) return false; break; case UI: if (!util::isConst(v, -1)) return false; break; case UB: if (!util::isConst(v, false)) return false; break; default: seqassert(0, "invalid parameters"); } i += 1; } return true; } struct InterAlignFunctionTransformer : public util::Operator { InterAlignTypes *types; std::vector params; void handle(ReturnInstr *x) override { seqassert(!x->getValue(), "function returns"); auto *M = x->getModule(); x->replaceAll(M->Nr(nullptr, /*final=*/true)); } void handle(CallInstr *x) override { if (!params.empty()) return; auto *M = x->getModule(); auto *I = M->getIntType(); auto *B = M->getBoolType(); auto *alignFunc = M->getOrRealizeMethod( types->seq, "align", {types->seq, types->seq, I, I, I, I, I, I, I, I, I, I, B, B, B, B, B, B, B, B, B, B, B}); auto *func = cast(util::getFunc(x->getCallee())); if (!(func && alignFunc && util::match(func, alignFunc) && verifyAlignParams(x->begin() + 2, x->end()))) return; params = std::vector(x->begin(), x->end()); Value *self = x->front(); Value *other = *(x->begin() + 1); Value *extzOnly = params[17]; Value *revCigar = params[18]; Value *yieldValue = (*types->yield)(*self, *other, *extzOnly, *revCigar); auto *yieldOut = M->Nr(yieldValue); auto *yieldIn = M->Nr(types->yield, /*suspend=*/false); auto *alnResult = M->Nr(yieldIn, "aln"); x->replaceAll(M->Nr(util::series(yieldOut), alnResult)); } InterAlignFunctionTransformer(InterAlignTypes *types) : util::Operator(), types(types), params() {} Value *getParams() { // order of 'args': a, b, ambig, gapo, gape, score_only, bandwidth, zdrop, end_bonus std::vector 
args = {params[2], params[3], params[4], params[5], params[6], params[12], params[9], params[10], params[11]}; return types->params->construct(args); } }; void PipelineInterAlignOptimization::handle(PipelineFlow *p) { if (isParallel(p)) return; auto *M = p->getModule(); auto types = gatherInterAlignTypes(M); if (!types) // bio module not loaded; nothing to do return; PipelineFlow::Stage *prev = nullptr; util::CloneVisitor cv(M); for (auto it = p->begin(); it != p->end(); ++it) { if (auto *func = getStageFunc(*it)) { if (!it->isGenerator() && util::hasAttribute(func, "std.bio.builtin.inter_align") && util::getReturnType(func)->is(M->getVoidType())) { // transform aligning function InterAlignFunctionTransformer aft(&types); auto *clone = cast(cv.forceClone(func)); util::setReturnType(clone, M->getGeneratorType(types.yield)); clone->setGenerator(); clone->getBody()->accept(aft); if (aft.params.empty()) continue; // make sure the arguments are in the correct order auto *inputType = prev->getOutputElementType(); clone = makeStageWrapperFunc(&*it, clone, inputType); auto *coroType = cast(clone->getType()); // vars // following defs are from bio/align.seq const int LEN_LIMIT = 512; const int MAX_SEQ_LEN8 = 128; const int MAX_SEQ_LEN16 = 32768; const unsigned W = SCHED_WIDTH_INTERALIGN; auto *intType = M->getIntType(); auto *intPtrType = M->getPointerType(intType); auto *i32 = M->getIntNType(32, true); auto *parent = cast(getParentFunc()); seqassert(parent, "not in a function"); auto *init = M->Nr(); auto *states = util::makeVar(util::alloc(coroType->getReturnType(), W), init, parent); auto *statesTemp = util::makeVar(util::alloc(coroType->getReturnType(), W), init, parent); auto *pairs = util::makeVar(util::alloc(types.pair, W), init, parent); auto *pairsTemp = util::makeVar(util::alloc(types.pair, W), init, parent); auto *bufRef = util::makeVar(util::alloc(M->getByteType(), LEN_LIMIT * W), init, parent); auto *bufQer = util::makeVar(util::alloc(M->getByteType(), 
LEN_LIMIT * W), init, parent); auto *hist = util::makeVar(util::alloc(i32, MAX_SEQ_LEN8 + MAX_SEQ_LEN16 + 32), init, parent); auto *filled = util::makeVar(M->getInt(0), init, parent); insertBefore(init); auto *width = M->getInt(W); auto *params = aft.getParams(); std::vector stageArgTypes; std::vector stageArgs; for (auto *arg : *it) { if (arg) { stageArgs.push_back(arg); stageArgTypes.push_back(arg->getType()); } } auto *extraArgs = util::makeTuple(stageArgs, M); auto *schedFunc = M->getOrRealizeFunc( "_interaln_scheduler", {inputType, coroType, pairs->getType(), bufRef->getType(), bufQer->getType(), states->getType(), types.params, hist->getType(), pairsTemp->getType(), statesTemp->getType(), intPtrType, intType, extraArgs->getType()}, {}, alignModule); auto *flushFunc = M->getOrRealizeFunc( "_interaln_flush", {pairs->getType(), bufRef->getType(), bufQer->getType(), states->getType(), M->getIntType(), types.params, hist->getType(), pairsTemp->getType(), statesTemp->getType()}, {}, alignModule); seqassert(schedFunc, "could not realize scheduler"); seqassert(flushFunc, "could not realize flush"); PipelineFlow::Stage stage(replaceStageFunc(*it, schedFunc, cv), {nullptr, M->Nr(clone), pairs, bufRef, bufQer, states, params, hist, pairsTemp, statesTemp, M->Nr(filled->getVar()), width, extraArgs}, /*generator=*/false, /*parallel=*/false); *it = stage; auto *drain = util::call(flushFunc, {pairs, bufRef, bufQer, states, filled, params, hist, pairsTemp, statesTemp}); insertAfter(drain); break; // at most one inter-sequence alignment transformation per pipeline } } prev = &*it; } } } // namespace seq