mirror of https://github.com/open-mmlab/mmocr.git

simplify installation procedure (#188)

* postproc
* mv postproc of textdet in mmcv
* linting
* change deps mmcv 1.3.4
* update deps
* linting
* Update configs/textdet/panet/panet_r18_fpem_ffm_600e_icdar2015.py
* Update docs/changelog.md
* Update mmocr/models/textdet/postprocess/__init__.py
* Update mmocr/version.py
* resize&recompress jpg
* Update changelog.md

Co-authored-by: Hongbin Sun <hongbin306@gmail.com>

Branch: pull/193/head
parent cbdd98a1e1
commit 94cd52dc4a
@@ -56,7 +56,7 @@ jobs:
       - name: Install PyTorch
         run: pip install torch==${{matrix.torch}}+cpu torchvision==${{matrix.torchvision}}+cpu -f https://download.pytorch.org/whl/torch_stable.html
       - name: Install MMCV
-        run: pip install mmcv-full==1.3.1 -f https://download.openmmlab.com/mmcv/dist/cpu/torch${{matrix.torch}}/index.html
+        run: pip install mmcv-full==1.3.4 -f https://download.openmmlab.com/mmcv/dist/cpu/torch${{matrix.torch}}/index.html
       - name: Install MMDet
         run: pip install mmdet==2.11.0
       - name: Install other dependencies
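For reference, a hedged expansion of the CI install step above for one concrete case; the torch 1.5.0 / torchvision 0.6.0 pairing is borrowed from the conda example later in this diff, not from the workflow's build matrix.

```shell
# Illustrative local equivalent of the CI step; the actual versions in CI come
# from the matrix, so the values below are assumptions for concreteness.
pip install torch==1.5.0+cpu torchvision==0.6.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
pip install mmcv-full==1.3.4 -f https://download.openmmlab.com/mmcv/dist/cpu/torch1.5.0/index.html
pip install mmdet==2.11.0
```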
Binary file not shown. (Image changed; size after: 618 KiB.)
@@ -13,7 +13,7 @@ RUN apt-get update && apt-get install -y git ninja-build libglib2.0-0 libsm6 lib
     && rm -rf /var/lib/apt/lists/*

 RUN conda clean --all
-RUN pip install mmcv-full==1.3.1+torch1.5.0+cu101 -f https://download.openmmlab.com/mmcv/dist/index.html
+RUN pip install mmcv-full==1.3.4+torch1.5.0+cu101 -f https://download.openmmlab.com/mmcv/dist/index.html

 RUN pip install mmdet==2.11.0
@@ -1,4 +1,13 @@
 # Changelog
+
+## v0.2.0 (16/5/2021)
+
+**Highlights**
+
+- MMOCR is compilation-free: textdet postprocessing and custom ops have been moved into mmcv 1.3.4 or later.
+- Add a new OCR downstream task: NER.
+- Add two new text detection methods: DRRG and FCENet.
+- Add a new text recognition method: TPS-CRNN.
+- Add an end-to-end demo.

 ## v0.1.0 (7/4/2021)
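A hedged way to check the first highlight on an installed environment: with mmcv-full 1.3.4 or later the text-detection postprocessing ops should import from `mmcv.ops` instead of being compiled inside MMOCR. The op names `contour_expand` and `pixel_group` below are an assumption about mmcv's ops module, not something stated in this diff.

```shell
# Assumed sanity check: the relocated postprocessing ops ship prebuilt with
# mmcv-full >= 1.3.4, so this should import without any local compilation.
python -c "from mmcv.ops import contour_expand, pixel_group; print('mmcv ops available')"
```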
@@ -9,7 +9,7 @@
 - CUDA 10.1
 - NCCL 2
 - GCC 5.4.0 or higher
-- [MMCV](https://mmcv.readthedocs.io/en/latest/#installation) 1.3.1
+- [MMCV](https://mmcv.readthedocs.io/en/latest/#installation) 1.3.4
 - [MMDetection](https://mmdetection.readthedocs.io/en/latest/#installation) 2.11.0

 We have tested the following versions of OS and software:
@@ -17,7 +17,7 @@ We have tested the following versions of OS and software:
 - OS: Ubuntu 16.04
 - CUDA: 10.1
 - GCC(G++): 5.4.0
-- MMCV 1.3.1
+- MMCV 1.3.4
 - MMDetection 2.11.0
 - PyTorch 1.5
 - torchvision 0.6.0
@@ -53,6 +53,7 @@ Please replace ``{cu_version}`` and ``{torch_version}`` in the url with your desired versions.
 ```shell
 pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/cu110/torch1.7.0/index.html
 ```
+Note that mmocr 0.2.0 or later requires mmcv 1.3.4 or later.

 If it compiles during installation, please check that the CUDA version and PyTorch version **exactly** match the versions in the mmcv-full installation command. For example, PyTorch 1.7.0 and 1.7.1 are treated differently.
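A small sketch, not part of the original guide, for picking a matching prebuilt wheel: print the local PyTorch and CUDA versions first, then substitute them into the index URL. The cu110/torch1.7.0 pairing below is just the example taken from the command above.

```shell
# Check which torch/CUDA build is installed locally, then choose the matching
# {cu_version}/{torch_version} pair in the mmcv-full index URL.
python -c "import torch; print(torch.__version__, torch.version.cuda)"
pip install mmcv-full==1.3.4 -f https://download.openmmlab.com/mmcv/dist/cu110/torch1.7.0/index.html
```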
@@ -97,7 +98,7 @@ conda activate open-mmlab
 conda install pytorch==1.5.0 torchvision==0.6.0 cudatoolkit=10.1 -c pytorch

 # install the latest mmcv-full
-pip install mmcv-full==1.3.1
+pip install mmcv-full==1.3.4

 # install mmdetection
 pip install mmdet==2.11.0
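A short verification step, added here as a hedged suggestion rather than part of the original instructions, to confirm that the pinned versions were actually picked up:

```shell
# Verify that the installed dependencies match the pins above.
python -c "import mmcv, mmdet; print(mmcv.__version__, mmdet.__version__)"
# expected: 1.3.4 2.11.0
```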
File diff suppressed because it is too large
@ -1,402 +0,0 @@
|
|||
/*******************************************************************************
|
||||
* *
|
||||
* Author : Angus Johnson *
|
||||
* Version : 6.4.0 *
|
||||
* Date : 2 July 2015 *
|
||||
* Website : http://www.angusj.com *
|
||||
* Copyright : Angus Johnson 2010-2015 *
|
||||
* *
|
||||
* License: *
|
||||
* Use, modification & distribution is subject to Boost Software License Ver 1. *
|
||||
* http://www.boost.org/LICENSE_1_0.txt *
|
||||
* *
|
||||
* Attributions: *
|
||||
* The code in this library is an extension of Bala Vatti's clipping algorithm: *
|
||||
* "A generic solution to polygon clipping" *
|
||||
* Communications of the ACM, Vol 35, Issue 7 (July 1992) pp 56-63. *
|
||||
* http://portal.acm.org/citation.cfm?id=129906 *
|
||||
* *
|
||||
* Computer graphics and geometric modeling: implementation and algorithms *
|
||||
* By Max K. Agoston *
|
||||
* Springer; 1 edition (January 4, 2005) *
|
||||
* http://books.google.com/books?q=vatti+clipping+agoston *
|
||||
* *
|
||||
* See also: *
|
||||
* "Polygon Offsetting by Computing Winding Numbers" *
|
||||
* Paper no. DETC2005-85513 pp. 565-575 *
|
||||
* ASME 2005 International Design Engineering Technical Conferences *
|
||||
* and Computers and Information in Engineering Conference (IDETC/CIE2005) *
|
||||
* September 24-28, 2005 , Long Beach, California, USA *
|
||||
* http://www.me.berkeley.edu/~mcmains/pubs/DAC05OffsetPolygon.pdf *
|
||||
* *
|
||||
*******************************************************************************/
|
||||
|
||||
#ifndef clipper_hpp
|
||||
#define clipper_hpp
|
||||
|
||||
#define CLIPPER_VERSION "6.2.6"
|
||||
|
||||
//use_int32: When enabled 32bit ints are used instead of 64bit ints. This
|
||||
//improve performance but coordinate values are limited to the range +/- 46340
|
||||
//#define use_int32
|
||||
|
||||
//use_xyz: adds a Z member to IntPoint. Adds a minor cost to perfomance.
|
||||
//#define use_xyz
|
||||
|
||||
//use_lines: Enables line clipping. Adds a very minor cost to performance.
|
||||
#define use_lines
|
||||
|
||||
//use_deprecated: Enables temporary support for the obsolete functions
|
||||
//#define use_deprecated
|
||||
|
||||
#include <vector>
|
||||
#include <list>
|
||||
#include <set>
|
||||
#include <stdexcept>
|
||||
#include <cstring>
|
||||
#include <cstdlib>
|
||||
#include <ostream>
|
||||
#include <functional>
|
||||
#include <queue>
|
||||
|
||||
namespace ClipperLib {
|
||||
|
||||
enum ClipType { ctIntersection, ctUnion, ctDifference, ctXor };
|
||||
enum PolyType { ptSubject, ptClip };
|
||||
//By far the most widely used winding rules for polygon filling are
|
||||
//EvenOdd & NonZero (GDI, GDI+, XLib, OpenGL, Cairo, AGG, Quartz, SVG, Gr32)
|
||||
//Others rules include Positive, Negative and ABS_GTR_EQ_TWO (only in OpenGL)
|
||||
//see http://glprogramming.com/red/chapter11.html
|
||||
enum PolyFillType { pftEvenOdd, pftNonZero, pftPositive, pftNegative };
|
||||
|
||||
#ifdef use_int32
|
||||
typedef int cInt;
|
||||
static cInt const loRange = 0x7FFF;
|
||||
static cInt const hiRange = 0x7FFF;
|
||||
#else
|
||||
typedef signed long long cInt;
|
||||
static cInt const loRange = 0x3FFFFFFF;
|
||||
static cInt const hiRange = 0x3FFFFFFFFFFFFFFFLL;
|
||||
typedef signed long long long64; //used by Int128 class
|
||||
typedef unsigned long long ulong64;
|
||||
|
||||
#endif
|
||||
|
||||
struct IntPoint {
|
||||
cInt X;
|
||||
cInt Y;
|
||||
#ifdef use_xyz
|
||||
cInt Z;
|
||||
IntPoint(cInt x = 0, cInt y = 0, cInt z = 0): X(x), Y(y), Z(z) {};
|
||||
#else
|
||||
IntPoint(cInt x = 0, cInt y = 0): X(x), Y(y) {};
|
||||
#endif
|
||||
|
||||
friend inline bool operator== (const IntPoint& a, const IntPoint& b)
|
||||
{
|
||||
return a.X == b.X && a.Y == b.Y;
|
||||
}
|
||||
friend inline bool operator!= (const IntPoint& a, const IntPoint& b)
|
||||
{
|
||||
return a.X != b.X || a.Y != b.Y;
|
||||
}
|
||||
};
|
||||
//------------------------------------------------------------------------------
|
||||
|
||||
typedef std::vector< IntPoint > Path;
|
||||
typedef std::vector< Path > Paths;
|
||||
|
||||
inline Path& operator <<(Path& poly, const IntPoint& p) {poly.push_back(p); return poly;}
|
||||
inline Paths& operator <<(Paths& polys, const Path& p) {polys.push_back(p); return polys;}
|
||||
|
||||
std::ostream& operator <<(std::ostream &s, const IntPoint &p);
|
||||
std::ostream& operator <<(std::ostream &s, const Path &p);
|
||||
std::ostream& operator <<(std::ostream &s, const Paths &p);
|
||||
|
||||
struct DoublePoint
|
||||
{
|
||||
double X;
|
||||
double Y;
|
||||
DoublePoint(double x = 0, double y = 0) : X(x), Y(y) {}
|
||||
DoublePoint(IntPoint ip) : X((double)ip.X), Y((double)ip.Y) {}
|
||||
};
|
||||
//------------------------------------------------------------------------------
|
||||
|
||||
#ifdef use_xyz
|
||||
typedef void (*ZFillCallback)(IntPoint& e1bot, IntPoint& e1top, IntPoint& e2bot, IntPoint& e2top, IntPoint& pt);
|
||||
#endif
|
||||
|
||||
enum InitOptions {ioReverseSolution = 1, ioStrictlySimple = 2, ioPreserveCollinear = 4};
|
||||
enum JoinType {jtSquare, jtRound, jtMiter};
|
||||
enum EndType {etClosedPolygon, etClosedLine, etOpenButt, etOpenSquare, etOpenRound};
|
||||
|
||||
class PolyNode;
|
||||
typedef std::vector< PolyNode* > PolyNodes;
|
||||
|
||||
class PolyNode
|
||||
{
|
||||
public:
|
||||
PolyNode();
|
||||
virtual ~PolyNode(){};
|
||||
Path Contour;
|
||||
PolyNodes Childs;
|
||||
PolyNode* Parent;
|
||||
PolyNode* GetNext() const;
|
||||
bool IsHole() const;
|
||||
bool IsOpen() const;
|
||||
int ChildCount() const;
|
||||
private:
|
||||
unsigned Index; //node index in Parent.Childs
|
||||
bool m_IsOpen;
|
||||
JoinType m_jointype;
|
||||
EndType m_endtype;
|
||||
PolyNode* GetNextSiblingUp() const;
|
||||
void AddChild(PolyNode& child);
|
||||
friend class Clipper; //to access Index
|
||||
friend class ClipperOffset;
|
||||
};
|
||||
|
||||
class PolyTree: public PolyNode
|
||||
{
|
||||
public:
|
||||
~PolyTree(){Clear();};
|
||||
PolyNode* GetFirst() const;
|
||||
void Clear();
|
||||
int Total() const;
|
||||
private:
|
||||
PolyNodes AllNodes;
|
||||
friend class Clipper; //to access AllNodes
|
||||
};
|
||||
|
||||
bool Orientation(const Path &poly);
|
||||
double Area(const Path &poly);
|
||||
int PointInPolygon(const IntPoint &pt, const Path &path);
|
||||
|
||||
void SimplifyPolygon(const Path &in_poly, Paths &out_polys, PolyFillType fillType = pftEvenOdd);
|
||||
void SimplifyPolygons(const Paths &in_polys, Paths &out_polys, PolyFillType fillType = pftEvenOdd);
|
||||
void SimplifyPolygons(Paths &polys, PolyFillType fillType = pftEvenOdd);
|
||||
|
||||
void CleanPolygon(const Path& in_poly, Path& out_poly, double distance = 1.415);
|
||||
void CleanPolygon(Path& poly, double distance = 1.415);
|
||||
void CleanPolygons(const Paths& in_polys, Paths& out_polys, double distance = 1.415);
|
||||
void CleanPolygons(Paths& polys, double distance = 1.415);
|
||||
|
||||
void MinkowskiSum(const Path& pattern, const Path& path, Paths& solution, bool pathIsClosed);
|
||||
void MinkowskiSum(const Path& pattern, const Paths& paths, Paths& solution, bool pathIsClosed);
|
||||
void MinkowskiDiff(const Path& poly1, const Path& poly2, Paths& solution);
|
||||
|
||||
void PolyTreeToPaths(const PolyTree& polytree, Paths& paths);
|
||||
void ClosedPathsFromPolyTree(const PolyTree& polytree, Paths& paths);
|
||||
void OpenPathsFromPolyTree(PolyTree& polytree, Paths& paths);
|
||||
|
||||
void ReversePath(Path& p);
|
||||
void ReversePaths(Paths& p);
|
||||
|
||||
struct IntRect { cInt left; cInt top; cInt right; cInt bottom; };
|
||||
|
||||
//enums that are used internally ...
|
||||
enum EdgeSide { esLeft = 1, esRight = 2};
|
||||
|
||||
//forward declarations (for stuff used internally) ...
|
||||
struct TEdge;
|
||||
struct IntersectNode;
|
||||
struct LocalMinimum;
|
||||
struct OutPt;
|
||||
struct OutRec;
|
||||
struct Join;
|
||||
|
||||
typedef std::vector < OutRec* > PolyOutList;
|
||||
typedef std::vector < TEdge* > EdgeList;
|
||||
typedef std::vector < Join* > JoinList;
|
||||
typedef std::vector < IntersectNode* > IntersectList;
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
|
||||
//ClipperBase is the ancestor to the Clipper class. It should not be
|
||||
//instantiated directly. This class simply abstracts the conversion of sets of
|
||||
//polygon coordinates into edge objects that are stored in a LocalMinima list.
|
||||
class ClipperBase
|
||||
{
|
||||
public:
|
||||
ClipperBase();
|
||||
virtual ~ClipperBase();
|
||||
virtual bool AddPath(const Path &pg, PolyType PolyTyp, bool Closed);
|
||||
bool AddPaths(const Paths &ppg, PolyType PolyTyp, bool Closed);
|
||||
virtual void Clear();
|
||||
IntRect GetBounds();
|
||||
bool PreserveCollinear() {return m_PreserveCollinear;};
|
||||
void PreserveCollinear(bool value) {m_PreserveCollinear = value;};
|
||||
protected:
|
||||
void DisposeLocalMinimaList();
|
||||
TEdge* AddBoundsToLML(TEdge *e, bool IsClosed);
|
||||
virtual void Reset();
|
||||
TEdge* ProcessBound(TEdge* E, bool IsClockwise);
|
||||
void InsertScanbeam(const cInt Y);
|
||||
bool PopScanbeam(cInt &Y);
|
||||
bool LocalMinimaPending();
|
||||
bool PopLocalMinima(cInt Y, const LocalMinimum *&locMin);
|
||||
OutRec* CreateOutRec();
|
||||
void DisposeAllOutRecs();
|
||||
void DisposeOutRec(PolyOutList::size_type index);
|
||||
void SwapPositionsInAEL(TEdge *edge1, TEdge *edge2);
|
||||
void DeleteFromAEL(TEdge *e);
|
||||
void UpdateEdgeIntoAEL(TEdge *&e);
|
||||
|
||||
typedef std::vector<LocalMinimum> MinimaList;
|
||||
MinimaList::iterator m_CurrentLM;
|
||||
MinimaList m_MinimaList;
|
||||
|
||||
bool m_UseFullRange;
|
||||
EdgeList m_edges;
|
||||
bool m_PreserveCollinear;
|
||||
bool m_HasOpenPaths;
|
||||
PolyOutList m_PolyOuts;
|
||||
TEdge *m_ActiveEdges;
|
||||
|
||||
typedef std::priority_queue<cInt> ScanbeamList;
|
||||
ScanbeamList m_Scanbeam;
|
||||
};
|
||||
//------------------------------------------------------------------------------
|
||||
|
||||
class Clipper : public virtual ClipperBase
|
||||
{
|
||||
public:
|
||||
Clipper(int initOptions = 0);
|
||||
bool Execute(ClipType clipType,
|
||||
Paths &solution,
|
||||
PolyFillType fillType = pftEvenOdd);
|
||||
bool Execute(ClipType clipType,
|
||||
Paths &solution,
|
||||
PolyFillType subjFillType,
|
||||
PolyFillType clipFillType);
|
||||
bool Execute(ClipType clipType,
|
||||
PolyTree &polytree,
|
||||
PolyFillType fillType = pftEvenOdd);
|
||||
bool Execute(ClipType clipType,
|
||||
PolyTree &polytree,
|
||||
PolyFillType subjFillType,
|
||||
PolyFillType clipFillType);
|
||||
bool ReverseSolution() { return m_ReverseOutput; };
|
||||
void ReverseSolution(bool value) {m_ReverseOutput = value;};
|
||||
bool StrictlySimple() {return m_StrictSimple;};
|
||||
void StrictlySimple(bool value) {m_StrictSimple = value;};
|
||||
//set the callback function for z value filling on intersections (otherwise Z is 0)
|
||||
#ifdef use_xyz
|
||||
void ZFillFunction(ZFillCallback zFillFunc);
|
||||
#endif
|
||||
protected:
|
||||
virtual bool ExecuteInternal();
|
||||
private:
|
||||
JoinList m_Joins;
|
||||
JoinList m_GhostJoins;
|
||||
IntersectList m_IntersectList;
|
||||
ClipType m_ClipType;
|
||||
typedef std::list<cInt> MaximaList;
|
||||
MaximaList m_Maxima;
|
||||
TEdge *m_SortedEdges;
|
||||
bool m_ExecuteLocked;
|
||||
PolyFillType m_ClipFillType;
|
||||
PolyFillType m_SubjFillType;
|
||||
bool m_ReverseOutput;
|
||||
bool m_UsingPolyTree;
|
||||
bool m_StrictSimple;
|
||||
#ifdef use_xyz
|
||||
ZFillCallback m_ZFill; //custom callback
|
||||
#endif
|
||||
void SetWindingCount(TEdge& edge);
|
||||
bool IsEvenOddFillType(const TEdge& edge) const;
|
||||
bool IsEvenOddAltFillType(const TEdge& edge) const;
|
||||
void InsertLocalMinimaIntoAEL(const cInt botY);
|
||||
void InsertEdgeIntoAEL(TEdge *edge, TEdge* startEdge);
|
||||
void AddEdgeToSEL(TEdge *edge);
|
||||
bool PopEdgeFromSEL(TEdge *&edge);
|
||||
void CopyAELToSEL();
|
||||
void DeleteFromSEL(TEdge *e);
|
||||
void SwapPositionsInSEL(TEdge *edge1, TEdge *edge2);
|
||||
bool IsContributing(const TEdge& edge) const;
|
||||
bool IsTopHorz(const cInt XPos);
|
||||
void DoMaxima(TEdge *e);
|
||||
void ProcessHorizontals();
|
||||
void ProcessHorizontal(TEdge *horzEdge);
|
||||
void AddLocalMaxPoly(TEdge *e1, TEdge *e2, const IntPoint &pt);
|
||||
OutPt* AddLocalMinPoly(TEdge *e1, TEdge *e2, const IntPoint &pt);
|
||||
OutRec* GetOutRec(int idx);
|
||||
void AppendPolygon(TEdge *e1, TEdge *e2);
|
||||
void IntersectEdges(TEdge *e1, TEdge *e2, IntPoint &pt);
|
||||
OutPt* AddOutPt(TEdge *e, const IntPoint &pt);
|
||||
OutPt* GetLastOutPt(TEdge *e);
|
||||
bool ProcessIntersections(const cInt topY);
|
||||
void BuildIntersectList(const cInt topY);
|
||||
void ProcessIntersectList();
|
||||
void ProcessEdgesAtTopOfScanbeam(const cInt topY);
|
||||
void BuildResult(Paths& polys);
|
||||
void BuildResult2(PolyTree& polytree);
|
||||
void SetHoleState(TEdge *e, OutRec *outrec);
|
||||
void DisposeIntersectNodes();
|
||||
bool FixupIntersectionOrder();
|
||||
void FixupOutPolygon(OutRec &outrec);
|
||||
void FixupOutPolyline(OutRec &outrec);
|
||||
bool IsHole(TEdge *e);
|
||||
bool FindOwnerFromSplitRecs(OutRec &outRec, OutRec *&currOrfl);
|
||||
void FixHoleLinkage(OutRec &outrec);
|
||||
void AddJoin(OutPt *op1, OutPt *op2, const IntPoint offPt);
|
||||
void ClearJoins();
|
||||
void ClearGhostJoins();
|
||||
void AddGhostJoin(OutPt *op, const IntPoint offPt);
|
||||
bool JoinPoints(Join *j, OutRec* outRec1, OutRec* outRec2);
|
||||
void JoinCommonEdges();
|
||||
void DoSimplePolygons();
|
||||
void FixupFirstLefts1(OutRec* OldOutRec, OutRec* NewOutRec);
|
||||
void FixupFirstLefts2(OutRec* InnerOutRec, OutRec* OuterOutRec);
|
||||
void FixupFirstLefts3(OutRec* OldOutRec, OutRec* NewOutRec);
|
||||
#ifdef use_xyz
|
||||
void SetZ(IntPoint& pt, TEdge& e1, TEdge& e2);
|
||||
#endif
|
||||
};
|
||||
//------------------------------------------------------------------------------
|
||||
|
||||
class ClipperOffset
|
||||
{
|
||||
public:
|
||||
ClipperOffset(double miterLimit = 2.0, double roundPrecision = 0.25);
|
||||
~ClipperOffset();
|
||||
void AddPath(const Path& path, JoinType joinType, EndType endType);
|
||||
void AddPaths(const Paths& paths, JoinType joinType, EndType endType);
|
||||
void Execute(Paths& solution, double delta);
|
||||
void Execute(PolyTree& solution, double delta);
|
||||
void Clear();
|
||||
double MiterLimit;
|
||||
double ArcTolerance;
|
||||
private:
|
||||
Paths m_destPolys;
|
||||
Path m_srcPoly;
|
||||
Path m_destPoly;
|
||||
std::vector<DoublePoint> m_normals;
|
||||
double m_delta, m_sinA, m_sin, m_cos;
|
||||
double m_miterLim, m_StepsPerRad;
|
||||
IntPoint m_lowest;
|
||||
PolyNode m_polyNodes;
|
||||
|
||||
void FixOrientations();
|
||||
void DoOffset(double delta);
|
||||
void OffsetPoint(int j, int& k, JoinType jointype);
|
||||
void DoSquare(int j, int k);
|
||||
void DoMiter(int j, int k, double r);
|
||||
void DoRound(int j, int k);
|
||||
};
|
||||
//------------------------------------------------------------------------------
|
||||
|
||||
class clipperException : public std::exception
|
||||
{
|
||||
public:
|
||||
clipperException(const char* description): m_descr(description) {}
|
||||
virtual ~clipperException() throw() {}
|
||||
virtual const char* what() const throw() {return m_descr.c_str();}
|
||||
private:
|
||||
std::string m_descr;
|
||||
};
|
||||
//------------------------------------------------------------------------------
|
||||
|
||||
} //ClipperLib namespace
|
||||
|
||||
#endif //clipper_hpp
|
|
@ -1,492 +0,0 @@
|
|||
/*
|
||||
pybind11/attr.h: Infrastructure for processing custom
|
||||
type and function attributes
|
||||
|
||||
Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>
|
||||
|
||||
All rights reserved. Use of this source code is governed by a
|
||||
BSD-style license that can be found in the LICENSE file.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cast.h"
|
||||
|
||||
NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
|
||||
|
||||
/// \addtogroup annotations
|
||||
/// @{
|
||||
|
||||
/// Annotation for methods
|
||||
struct is_method { handle class_; is_method(const handle &c) : class_(c) { } };
|
||||
|
||||
/// Annotation for operators
|
||||
struct is_operator { };
|
||||
|
||||
/// Annotation for parent scope
|
||||
struct scope { handle value; scope(const handle &s) : value(s) { } };
|
||||
|
||||
/// Annotation for documentation
|
||||
struct doc { const char *value; doc(const char *value) : value(value) { } };
|
||||
|
||||
/// Annotation for function names
|
||||
struct name { const char *value; name(const char *value) : value(value) { } };
|
||||
|
||||
/// Annotation indicating that a function is an overload associated with a given "sibling"
|
||||
struct sibling { handle value; sibling(const handle &value) : value(value.ptr()) { } };
|
||||
|
||||
/// Annotation indicating that a class derives from another given type
|
||||
template <typename T> struct base {
|
||||
PYBIND11_DEPRECATED("base<T>() was deprecated in favor of specifying 'T' as a template argument to class_")
|
||||
base() { }
|
||||
};
|
||||
|
||||
/// Keep patient alive while nurse lives
|
||||
template <size_t Nurse, size_t Patient> struct keep_alive { };
|
||||
|
||||
/// Annotation indicating that a class is involved in a multiple inheritance relationship
|
||||
struct multiple_inheritance { };
|
||||
|
||||
/// Annotation which enables dynamic attributes, i.e. adds `__dict__` to a class
|
||||
struct dynamic_attr { };
|
||||
|
||||
/// Annotation which enables the buffer protocol for a type
|
||||
struct buffer_protocol { };
|
||||
|
||||
/// Annotation which requests that a special metaclass is created for a type
|
||||
struct metaclass {
|
||||
handle value;
|
||||
|
||||
PYBIND11_DEPRECATED("py::metaclass() is no longer required. It's turned on by default now.")
|
||||
metaclass() {}
|
||||
|
||||
/// Override pybind11's default metaclass
|
||||
explicit metaclass(handle value) : value(value) { }
|
||||
};
|
||||
|
||||
/// Annotation that marks a class as local to the module:
|
||||
struct module_local { const bool value; constexpr module_local(bool v = true) : value(v) { } };
|
||||
|
||||
/// Annotation to mark enums as an arithmetic type
|
||||
struct arithmetic { };
|
||||
|
||||
/** \rst
|
||||
A call policy which places one or more guard variables (``Ts...``) around the function call.
|
||||
|
||||
For example, this definition:
|
||||
|
||||
.. code-block:: cpp
|
||||
|
||||
m.def("foo", foo, py::call_guard<T>());
|
||||
|
||||
is equivalent to the following pseudocode:
|
||||
|
||||
.. code-block:: cpp
|
||||
|
||||
m.def("foo", [](args...) {
|
||||
T scope_guard;
|
||||
return foo(args...); // forwarded arguments
|
||||
});
|
||||
\endrst */
|
||||
template <typename... Ts> struct call_guard;
|
||||
|
||||
template <> struct call_guard<> { using type = detail::void_type; };
|
||||
|
||||
template <typename T>
|
||||
struct call_guard<T> {
|
||||
static_assert(std::is_default_constructible<T>::value,
|
||||
"The guard type must be default constructible");
|
||||
|
||||
using type = T;
|
||||
};
|
||||
|
||||
template <typename T, typename... Ts>
|
||||
struct call_guard<T, Ts...> {
|
||||
struct type {
|
||||
T guard{}; // Compose multiple guard types with left-to-right default-constructor order
|
||||
typename call_guard<Ts...>::type next{};
|
||||
};
|
||||
};
|
||||
|
||||
/// @} annotations
|
||||
|
||||
NAMESPACE_BEGIN(detail)
|
||||
/* Forward declarations */
|
||||
enum op_id : int;
|
||||
enum op_type : int;
|
||||
struct undefined_t;
|
||||
template <op_id id, op_type ot, typename L = undefined_t, typename R = undefined_t> struct op_;
|
||||
inline void keep_alive_impl(size_t Nurse, size_t Patient, function_call &call, handle ret);
|
||||
|
||||
/// Internal data structure which holds metadata about a keyword argument
|
||||
struct argument_record {
|
||||
const char *name; ///< Argument name
|
||||
const char *descr; ///< Human-readable version of the argument value
|
||||
handle value; ///< Associated Python object
|
||||
bool convert : 1; ///< True if the argument is allowed to convert when loading
|
||||
bool none : 1; ///< True if None is allowed when loading
|
||||
|
||||
argument_record(const char *name, const char *descr, handle value, bool convert, bool none)
|
||||
: name(name), descr(descr), value(value), convert(convert), none(none) { }
|
||||
};
|
||||
|
||||
/// Internal data structure which holds metadata about a bound function (signature, overloads, etc.)
|
||||
struct function_record {
|
||||
function_record()
|
||||
: is_constructor(false), is_new_style_constructor(false), is_stateless(false),
|
||||
is_operator(false), has_args(false), has_kwargs(false), is_method(false) { }
|
||||
|
||||
/// Function name
|
||||
char *name = nullptr; /* why no C++ strings? They generate heavier code.. */
|
||||
|
||||
// User-specified documentation string
|
||||
char *doc = nullptr;
|
||||
|
||||
/// Human-readable version of the function signature
|
||||
char *signature = nullptr;
|
||||
|
||||
/// List of registered keyword arguments
|
||||
std::vector<argument_record> args;
|
||||
|
||||
/// Pointer to lambda function which converts arguments and performs the actual call
|
||||
handle (*impl) (function_call &) = nullptr;
|
||||
|
||||
/// Storage for the wrapped function pointer and captured data, if any
|
||||
void *data[3] = { };
|
||||
|
||||
/// Pointer to custom destructor for 'data' (if needed)
|
||||
void (*free_data) (function_record *ptr) = nullptr;
|
||||
|
||||
/// Return value policy associated with this function
|
||||
return_value_policy policy = return_value_policy::automatic;
|
||||
|
||||
/// True if name == '__init__'
|
||||
bool is_constructor : 1;
|
||||
|
||||
/// True if this is a new-style `__init__` defined in `detail/init.h`
|
||||
bool is_new_style_constructor : 1;
|
||||
|
||||
/// True if this is a stateless function pointer
|
||||
bool is_stateless : 1;
|
||||
|
||||
/// True if this is an operator (__add__), etc.
|
||||
bool is_operator : 1;
|
||||
|
||||
/// True if the function has a '*args' argument
|
||||
bool has_args : 1;
|
||||
|
||||
/// True if the function has a '**kwargs' argument
|
||||
bool has_kwargs : 1;
|
||||
|
||||
/// True if this is a method
|
||||
bool is_method : 1;
|
||||
|
||||
/// Number of arguments (including py::args and/or py::kwargs, if present)
|
||||
std::uint16_t nargs;
|
||||
|
||||
/// Python method object
|
||||
PyMethodDef *def = nullptr;
|
||||
|
||||
/// Python handle to the parent scope (a class or a module)
|
||||
handle scope;
|
||||
|
||||
/// Python handle to the sibling function representing an overload chain
|
||||
handle sibling;
|
||||
|
||||
/// Pointer to next overload
|
||||
function_record *next = nullptr;
|
||||
};
|
||||
|
||||
/// Special data structure which (temporarily) holds metadata about a bound class
|
||||
struct type_record {
|
||||
PYBIND11_NOINLINE type_record()
|
||||
: multiple_inheritance(false), dynamic_attr(false), buffer_protocol(false), module_local(false) { }
|
||||
|
||||
/// Handle to the parent scope
|
||||
handle scope;
|
||||
|
||||
/// Name of the class
|
||||
const char *name = nullptr;
|
||||
|
||||
// Pointer to RTTI type_info data structure
|
||||
const std::type_info *type = nullptr;
|
||||
|
||||
/// How large is the underlying C++ type?
|
||||
size_t type_size = 0;
|
||||
|
||||
/// What is the alignment of the underlying C++ type?
|
||||
size_t type_align = 0;
|
||||
|
||||
/// How large is the type's holder?
|
||||
size_t holder_size = 0;
|
||||
|
||||
/// The global operator new can be overridden with a class-specific variant
|
||||
void *(*operator_new)(size_t) = nullptr;
|
||||
|
||||
/// Function pointer to class_<..>::init_instance
|
||||
void (*init_instance)(instance *, const void *) = nullptr;
|
||||
|
||||
/// Function pointer to class_<..>::dealloc
|
||||
void (*dealloc)(detail::value_and_holder &) = nullptr;
|
||||
|
||||
/// List of base classes of the newly created type
|
||||
list bases;
|
||||
|
||||
/// Optional docstring
|
||||
const char *doc = nullptr;
|
||||
|
||||
/// Custom metaclass (optional)
|
||||
handle metaclass;
|
||||
|
||||
/// Multiple inheritance marker
|
||||
bool multiple_inheritance : 1;
|
||||
|
||||
/// Does the class manage a __dict__?
|
||||
bool dynamic_attr : 1;
|
||||
|
||||
/// Does the class implement the buffer protocol?
|
||||
bool buffer_protocol : 1;
|
||||
|
||||
/// Is the default (unique_ptr) holder type used?
|
||||
bool default_holder : 1;
|
||||
|
||||
/// Is the class definition local to the module shared object?
|
||||
bool module_local : 1;
|
||||
|
||||
PYBIND11_NOINLINE void add_base(const std::type_info &base, void *(*caster)(void *)) {
|
||||
auto base_info = detail::get_type_info(base, false);
|
||||
if (!base_info) {
|
||||
std::string tname(base.name());
|
||||
detail::clean_type_id(tname);
|
||||
pybind11_fail("generic_type: type \"" + std::string(name) +
|
||||
"\" referenced unknown base type \"" + tname + "\"");
|
||||
}
|
||||
|
||||
if (default_holder != base_info->default_holder) {
|
||||
std::string tname(base.name());
|
||||
detail::clean_type_id(tname);
|
||||
pybind11_fail("generic_type: type \"" + std::string(name) + "\" " +
|
||||
(default_holder ? "does not have" : "has") +
|
||||
" a non-default holder type while its base \"" + tname + "\" " +
|
||||
(base_info->default_holder ? "does not" : "does"));
|
||||
}
|
||||
|
||||
bases.append((PyObject *) base_info->type);
|
||||
|
||||
if (base_info->type->tp_dictoffset != 0)
|
||||
dynamic_attr = true;
|
||||
|
||||
if (caster)
|
||||
base_info->implicit_casts.emplace_back(type, caster);
|
||||
}
|
||||
};
|
||||
|
||||
inline function_call::function_call(const function_record &f, handle p) :
|
||||
func(f), parent(p) {
|
||||
args.reserve(f.nargs);
|
||||
args_convert.reserve(f.nargs);
|
||||
}
|
||||
|
||||
/// Tag for a new-style `__init__` defined in `detail/init.h`
|
||||
struct is_new_style_constructor { };
|
||||
|
||||
/**
|
||||
* Partial template specializations to process custom attributes provided to
|
||||
* cpp_function_ and class_. These are either used to initialize the respective
|
||||
* fields in the type_record and function_record data structures or executed at
|
||||
* runtime to deal with custom call policies (e.g. keep_alive).
|
||||
*/
|
||||
template <typename T, typename SFINAE = void> struct process_attribute;
|
||||
|
||||
template <typename T> struct process_attribute_default {
|
||||
/// Default implementation: do nothing
|
||||
static void init(const T &, function_record *) { }
|
||||
static void init(const T &, type_record *) { }
|
||||
static void precall(function_call &) { }
|
||||
static void postcall(function_call &, handle) { }
|
||||
};
|
||||
|
||||
/// Process an attribute specifying the function's name
|
||||
template <> struct process_attribute<name> : process_attribute_default<name> {
|
||||
static void init(const name &n, function_record *r) { r->name = const_cast<char *>(n.value); }
|
||||
};
|
||||
|
||||
/// Process an attribute specifying the function's docstring
|
||||
template <> struct process_attribute<doc> : process_attribute_default<doc> {
|
||||
static void init(const doc &n, function_record *r) { r->doc = const_cast<char *>(n.value); }
|
||||
};
|
||||
|
||||
/// Process an attribute specifying the function's docstring (provided as a C-style string)
|
||||
template <> struct process_attribute<const char *> : process_attribute_default<const char *> {
|
||||
static void init(const char *d, function_record *r) { r->doc = const_cast<char *>(d); }
|
||||
static void init(const char *d, type_record *r) { r->doc = const_cast<char *>(d); }
|
||||
};
|
||||
template <> struct process_attribute<char *> : process_attribute<const char *> { };
|
||||
|
||||
/// Process an attribute indicating the function's return value policy
|
||||
template <> struct process_attribute<return_value_policy> : process_attribute_default<return_value_policy> {
|
||||
static void init(const return_value_policy &p, function_record *r) { r->policy = p; }
|
||||
};
|
||||
|
||||
/// Process an attribute which indicates that this is an overloaded function associated with a given sibling
|
||||
template <> struct process_attribute<sibling> : process_attribute_default<sibling> {
|
||||
static void init(const sibling &s, function_record *r) { r->sibling = s.value; }
|
||||
};
|
||||
|
||||
/// Process an attribute which indicates that this function is a method
|
||||
template <> struct process_attribute<is_method> : process_attribute_default<is_method> {
|
||||
static void init(const is_method &s, function_record *r) { r->is_method = true; r->scope = s.class_; }
|
||||
};
|
||||
|
||||
/// Process an attribute which indicates the parent scope of a method
|
||||
template <> struct process_attribute<scope> : process_attribute_default<scope> {
|
||||
static void init(const scope &s, function_record *r) { r->scope = s.value; }
|
||||
};
|
||||
|
||||
/// Process an attribute which indicates that this function is an operator
|
||||
template <> struct process_attribute<is_operator> : process_attribute_default<is_operator> {
|
||||
static void init(const is_operator &, function_record *r) { r->is_operator = true; }
|
||||
};
|
||||
|
||||
template <> struct process_attribute<is_new_style_constructor> : process_attribute_default<is_new_style_constructor> {
|
||||
static void init(const is_new_style_constructor &, function_record *r) { r->is_new_style_constructor = true; }
|
||||
};
|
||||
|
||||
/// Process a keyword argument attribute (*without* a default value)
|
||||
template <> struct process_attribute<arg> : process_attribute_default<arg> {
|
||||
static void init(const arg &a, function_record *r) {
|
||||
if (r->is_method && r->args.empty())
|
||||
r->args.emplace_back("self", nullptr, handle(), true /*convert*/, false /*none not allowed*/);
|
||||
r->args.emplace_back(a.name, nullptr, handle(), !a.flag_noconvert, a.flag_none);
|
||||
}
|
||||
};
|
||||
|
||||
/// Process a keyword argument attribute (*with* a default value)
|
||||
template <> struct process_attribute<arg_v> : process_attribute_default<arg_v> {
|
||||
static void init(const arg_v &a, function_record *r) {
|
||||
if (r->is_method && r->args.empty())
|
||||
r->args.emplace_back("self", nullptr /*descr*/, handle() /*parent*/, true /*convert*/, false /*none not allowed*/);
|
||||
|
||||
if (!a.value) {
|
||||
#if !defined(NDEBUG)
|
||||
std::string descr("'");
|
||||
if (a.name) descr += std::string(a.name) + ": ";
|
||||
descr += a.type + "'";
|
||||
if (r->is_method) {
|
||||
if (r->name)
|
||||
descr += " in method '" + (std::string) str(r->scope) + "." + (std::string) r->name + "'";
|
||||
else
|
||||
descr += " in method of '" + (std::string) str(r->scope) + "'";
|
||||
} else if (r->name) {
|
||||
descr += " in function '" + (std::string) r->name + "'";
|
||||
}
|
||||
pybind11_fail("arg(): could not convert default argument "
|
||||
+ descr + " into a Python object (type not registered yet?)");
|
||||
#else
|
||||
pybind11_fail("arg(): could not convert default argument "
|
||||
"into a Python object (type not registered yet?). "
|
||||
"Compile in debug mode for more information.");
|
||||
#endif
|
||||
}
|
||||
r->args.emplace_back(a.name, a.descr, a.value.inc_ref(), !a.flag_noconvert, a.flag_none);
|
||||
}
|
||||
};
|
||||
|
||||
/// Process a parent class attribute. Single inheritance only (class_ itself already guarantees that)
|
||||
template <typename T>
|
||||
struct process_attribute<T, enable_if_t<is_pyobject<T>::value>> : process_attribute_default<handle> {
|
||||
static void init(const handle &h, type_record *r) { r->bases.append(h); }
|
||||
};
|
||||
|
||||
/// Process a parent class attribute (deprecated, does not support multiple inheritance)
|
||||
template <typename T>
|
||||
struct process_attribute<base<T>> : process_attribute_default<base<T>> {
|
||||
static void init(const base<T> &, type_record *r) { r->add_base(typeid(T), nullptr); }
|
||||
};
|
||||
|
||||
/// Process a multiple inheritance attribute
|
||||
template <>
|
||||
struct process_attribute<multiple_inheritance> : process_attribute_default<multiple_inheritance> {
|
||||
static void init(const multiple_inheritance &, type_record *r) { r->multiple_inheritance = true; }
|
||||
};
|
||||
|
||||
template <>
|
||||
struct process_attribute<dynamic_attr> : process_attribute_default<dynamic_attr> {
|
||||
static void init(const dynamic_attr &, type_record *r) { r->dynamic_attr = true; }
|
||||
};
|
||||
|
||||
template <>
|
||||
struct process_attribute<buffer_protocol> : process_attribute_default<buffer_protocol> {
|
||||
static void init(const buffer_protocol &, type_record *r) { r->buffer_protocol = true; }
|
||||
};
|
||||
|
||||
template <>
|
||||
struct process_attribute<metaclass> : process_attribute_default<metaclass> {
|
||||
static void init(const metaclass &m, type_record *r) { r->metaclass = m.value; }
|
||||
};
|
||||
|
||||
template <>
|
||||
struct process_attribute<module_local> : process_attribute_default<module_local> {
|
||||
static void init(const module_local &l, type_record *r) { r->module_local = l.value; }
|
||||
};
|
||||
|
||||
/// Process an 'arithmetic' attribute for enums (does nothing here)
|
||||
template <>
|
||||
struct process_attribute<arithmetic> : process_attribute_default<arithmetic> {};
|
||||
|
||||
template <typename... Ts>
|
||||
struct process_attribute<call_guard<Ts...>> : process_attribute_default<call_guard<Ts...>> { };
|
||||
|
||||
/**
|
||||
* Process a keep_alive call policy -- invokes keep_alive_impl during the
|
||||
* pre-call handler if both Nurse, Patient != 0 and use the post-call handler
|
||||
* otherwise
|
||||
*/
|
||||
template <size_t Nurse, size_t Patient> struct process_attribute<keep_alive<Nurse, Patient>> : public process_attribute_default<keep_alive<Nurse, Patient>> {
|
||||
template <size_t N = Nurse, size_t P = Patient, enable_if_t<N != 0 && P != 0, int> = 0>
|
||||
static void precall(function_call &call) { keep_alive_impl(Nurse, Patient, call, handle()); }
|
||||
template <size_t N = Nurse, size_t P = Patient, enable_if_t<N != 0 && P != 0, int> = 0>
|
||||
static void postcall(function_call &, handle) { }
|
||||
template <size_t N = Nurse, size_t P = Patient, enable_if_t<N == 0 || P == 0, int> = 0>
|
||||
static void precall(function_call &) { }
|
||||
template <size_t N = Nurse, size_t P = Patient, enable_if_t<N == 0 || P == 0, int> = 0>
|
||||
static void postcall(function_call &call, handle ret) { keep_alive_impl(Nurse, Patient, call, ret); }
|
||||
};
|
||||
|
||||
/// Recursively iterate over variadic template arguments
|
||||
template <typename... Args> struct process_attributes {
|
||||
static void init(const Args&... args, function_record *r) {
|
||||
int unused[] = { 0, (process_attribute<typename std::decay<Args>::type>::init(args, r), 0) ... };
|
||||
ignore_unused(unused);
|
||||
}
|
||||
static void init(const Args&... args, type_record *r) {
|
||||
int unused[] = { 0, (process_attribute<typename std::decay<Args>::type>::init(args, r), 0) ... };
|
||||
ignore_unused(unused);
|
||||
}
|
||||
static void precall(function_call &call) {
|
||||
int unused[] = { 0, (process_attribute<typename std::decay<Args>::type>::precall(call), 0) ... };
|
||||
ignore_unused(unused);
|
||||
}
|
||||
static void postcall(function_call &call, handle fn_ret) {
|
||||
int unused[] = { 0, (process_attribute<typename std::decay<Args>::type>::postcall(call, fn_ret), 0) ... };
|
||||
ignore_unused(unused);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
using is_call_guard = is_instantiation<call_guard, T>;
|
||||
|
||||
/// Extract the ``type`` from the first `call_guard` in `Extras...` (or `void_type` if none found)
|
||||
template <typename... Extra>
|
||||
using extract_guard_t = typename exactly_one_t<is_call_guard, call_guard<>, Extra...>::type;
|
||||
|
||||
/// Check the number of named arguments at compile time
|
||||
template <typename... Extra,
|
||||
size_t named = constexpr_sum(std::is_base_of<arg, Extra>::value...),
|
||||
size_t self = constexpr_sum(std::is_same<is_method, Extra>::value...)>
|
||||
constexpr bool expected_num_args(size_t nargs, bool has_args, bool has_kwargs) {
|
||||
return named == 0 || (self + named + has_args + has_kwargs) == nargs;
|
||||
}
|
||||
|
||||
NAMESPACE_END(detail)
|
||||
NAMESPACE_END(PYBIND11_NAMESPACE)
|
|
@ -1,108 +0,0 @@
|
|||
/*
|
||||
pybind11/buffer_info.h: Python buffer object interface
|
||||
|
||||
Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>
|
||||
|
||||
All rights reserved. Use of this source code is governed by a
|
||||
BSD-style license that can be found in the LICENSE file.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "detail/common.h"
|
||||
|
||||
NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
|
||||
|
||||
/// Information record describing a Python buffer object
|
||||
struct buffer_info {
|
||||
void *ptr = nullptr; // Pointer to the underlying storage
|
||||
ssize_t itemsize = 0; // Size of individual items in bytes
|
||||
ssize_t size = 0; // Total number of entries
|
||||
std::string format; // For homogeneous buffers, this should be set to format_descriptor<T>::format()
|
||||
ssize_t ndim = 0; // Number of dimensions
|
||||
std::vector<ssize_t> shape; // Shape of the tensor (1 entry per dimension)
|
||||
std::vector<ssize_t> strides; // Number of entries between adjacent entries (for each per dimension)
|
||||
|
||||
buffer_info() { }
|
||||
|
||||
buffer_info(void *ptr, ssize_t itemsize, const std::string &format, ssize_t ndim,
|
||||
detail::any_container<ssize_t> shape_in, detail::any_container<ssize_t> strides_in)
|
||||
: ptr(ptr), itemsize(itemsize), size(1), format(format), ndim(ndim),
|
||||
shape(std::move(shape_in)), strides(std::move(strides_in)) {
|
||||
if (ndim != (ssize_t) shape.size() || ndim != (ssize_t) strides.size())
|
||||
pybind11_fail("buffer_info: ndim doesn't match shape and/or strides length");
|
||||
for (size_t i = 0; i < (size_t) ndim; ++i)
|
||||
size *= shape[i];
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
buffer_info(T *ptr, detail::any_container<ssize_t> shape_in, detail::any_container<ssize_t> strides_in)
|
||||
: buffer_info(private_ctr_tag(), ptr, sizeof(T), format_descriptor<T>::format(), static_cast<ssize_t>(shape_in->size()), std::move(shape_in), std::move(strides_in)) { }
|
||||
|
||||
buffer_info(void *ptr, ssize_t itemsize, const std::string &format, ssize_t size)
|
||||
: buffer_info(ptr, itemsize, format, 1, {size}, {itemsize}) { }
|
||||
|
||||
template <typename T>
|
||||
buffer_info(T *ptr, ssize_t size)
|
||||
: buffer_info(ptr, sizeof(T), format_descriptor<T>::format(), size) { }
|
||||
|
||||
explicit buffer_info(Py_buffer *view, bool ownview = true)
|
||||
: buffer_info(view->buf, view->itemsize, view->format, view->ndim,
|
||||
{view->shape, view->shape + view->ndim}, {view->strides, view->strides + view->ndim}) {
|
||||
this->view = view;
|
||||
this->ownview = ownview;
|
||||
}
|
||||
|
||||
buffer_info(const buffer_info &) = delete;
|
||||
buffer_info& operator=(const buffer_info &) = delete;
|
||||
|
||||
buffer_info(buffer_info &&other) {
|
||||
(*this) = std::move(other);
|
||||
}
|
||||
|
||||
buffer_info& operator=(buffer_info &&rhs) {
|
||||
ptr = rhs.ptr;
|
||||
itemsize = rhs.itemsize;
|
||||
size = rhs.size;
|
||||
format = std::move(rhs.format);
|
||||
ndim = rhs.ndim;
|
||||
shape = std::move(rhs.shape);
|
||||
strides = std::move(rhs.strides);
|
||||
std::swap(view, rhs.view);
|
||||
std::swap(ownview, rhs.ownview);
|
||||
return *this;
|
||||
}
|
||||
|
||||
~buffer_info() {
|
||||
if (view && ownview) { PyBuffer_Release(view); delete view; }
|
||||
}
|
||||
|
||||
private:
|
||||
struct private_ctr_tag { };
|
||||
|
||||
buffer_info(private_ctr_tag, void *ptr, ssize_t itemsize, const std::string &format, ssize_t ndim,
|
||||
detail::any_container<ssize_t> &&shape_in, detail::any_container<ssize_t> &&strides_in)
|
||||
: buffer_info(ptr, itemsize, format, ndim, std::move(shape_in), std::move(strides_in)) { }
|
||||
|
||||
Py_buffer *view = nullptr;
|
||||
bool ownview = false;
|
||||
};
|
||||
|
||||
NAMESPACE_BEGIN(detail)
|
||||
|
||||
template <typename T, typename SFINAE = void> struct compare_buffer_info {
|
||||
static bool compare(const buffer_info& b) {
|
||||
return b.format == format_descriptor<T>::format() && b.itemsize == (ssize_t) sizeof(T);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T> struct compare_buffer_info<T, detail::enable_if_t<std::is_integral<T>::value>> {
|
||||
static bool compare(const buffer_info& b) {
|
||||
return (size_t) b.itemsize == sizeof(T) && (b.format == format_descriptor<T>::value ||
|
||||
((sizeof(T) == sizeof(long)) && b.format == (std::is_unsigned<T>::value ? "L" : "l")) ||
|
||||
((sizeof(T) == sizeof(size_t)) && b.format == (std::is_unsigned<T>::value ? "N" : "n")));
|
||||
}
|
||||
};
|
||||
|
||||
NAMESPACE_END(detail)
|
||||
NAMESPACE_END(PYBIND11_NAMESPACE)
|
File diff suppressed because it is too large
@ -1,162 +0,0 @@
|
|||
/*
|
||||
pybind11/chrono.h: Transparent conversion between std::chrono and python's datetime
|
||||
|
||||
Copyright (c) 2016 Trent Houliston <trent@houliston.me> and
|
||||
Wenzel Jakob <wenzel.jakob@epfl.ch>
|
||||
|
||||
All rights reserved. Use of this source code is governed by a
|
||||
BSD-style license that can be found in the LICENSE file.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "pybind11.h"
|
||||
#include <cmath>
|
||||
#include <ctime>
|
||||
#include <chrono>
|
||||
#include <datetime.h>
|
||||
|
||||
// Backport the PyDateTime_DELTA functions from Python3.3 if required
|
||||
#ifndef PyDateTime_DELTA_GET_DAYS
|
||||
#define PyDateTime_DELTA_GET_DAYS(o) (((PyDateTime_Delta*)o)->days)
|
||||
#endif
|
||||
#ifndef PyDateTime_DELTA_GET_SECONDS
|
||||
#define PyDateTime_DELTA_GET_SECONDS(o) (((PyDateTime_Delta*)o)->seconds)
|
||||
#endif
|
||||
#ifndef PyDateTime_DELTA_GET_MICROSECONDS
|
||||
#define PyDateTime_DELTA_GET_MICROSECONDS(o) (((PyDateTime_Delta*)o)->microseconds)
|
||||
#endif
|
||||
|
||||
NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
|
||||
NAMESPACE_BEGIN(detail)
|
||||
|
||||
template <typename type> class duration_caster {
|
||||
public:
|
||||
typedef typename type::rep rep;
|
||||
typedef typename type::period period;
|
||||
|
||||
typedef std::chrono::duration<uint_fast32_t, std::ratio<86400>> days;
|
||||
|
||||
bool load(handle src, bool) {
|
||||
using namespace std::chrono;
|
||||
|
||||
// Lazy initialise the PyDateTime import
|
||||
if (!PyDateTimeAPI) { PyDateTime_IMPORT; }
|
||||
|
||||
if (!src) return false;
|
||||
// If invoked with datetime.delta object
|
||||
if (PyDelta_Check(src.ptr())) {
|
||||
value = type(duration_cast<duration<rep, period>>(
|
||||
days(PyDateTime_DELTA_GET_DAYS(src.ptr()))
|
||||
+ seconds(PyDateTime_DELTA_GET_SECONDS(src.ptr()))
|
||||
+ microseconds(PyDateTime_DELTA_GET_MICROSECONDS(src.ptr()))));
|
||||
return true;
|
||||
}
|
||||
// If invoked with a float we assume it is seconds and convert
|
||||
else if (PyFloat_Check(src.ptr())) {
|
||||
value = type(duration_cast<duration<rep, period>>(duration<double>(PyFloat_AsDouble(src.ptr()))));
|
||||
return true;
|
||||
}
|
||||
else return false;
|
||||
}
|
||||
|
||||
// If this is a duration just return it back
|
||||
static const std::chrono::duration<rep, period>& get_duration(const std::chrono::duration<rep, period> &src) {
|
||||
return src;
|
||||
}
|
||||
|
||||
// If this is a time_point get the time_since_epoch
|
||||
template <typename Clock> static std::chrono::duration<rep, period> get_duration(const std::chrono::time_point<Clock, std::chrono::duration<rep, period>> &src) {
|
||||
return src.time_since_epoch();
|
||||
}
|
||||
|
||||
static handle cast(const type &src, return_value_policy /* policy */, handle /* parent */) {
|
||||
using namespace std::chrono;
|
||||
|
||||
// Use overloaded function to get our duration from our source
|
||||
// Works out if it is a duration or time_point and get the duration
|
||||
auto d = get_duration(src);
|
||||
|
||||
// Lazy initialise the PyDateTime import
|
||||
if (!PyDateTimeAPI) { PyDateTime_IMPORT; }
|
||||
|
||||
// Declare these special duration types so the conversions happen with the correct primitive types (int)
|
||||
using dd_t = duration<int, std::ratio<86400>>;
|
||||
using ss_t = duration<int, std::ratio<1>>;
|
||||
using us_t = duration<int, std::micro>;
|
||||
|
||||
auto dd = duration_cast<dd_t>(d);
|
||||
auto subd = d - dd;
|
||||
auto ss = duration_cast<ss_t>(subd);
|
||||
auto us = duration_cast<us_t>(subd - ss);
|
||||
return PyDelta_FromDSU(dd.count(), ss.count(), us.count());
|
||||
}
|
||||
|
||||
PYBIND11_TYPE_CASTER(type, _("datetime.timedelta"));
|
||||
};
|
||||
|
||||
// This is for casting times on the system clock into datetime.datetime instances
|
||||
template <typename Duration> class type_caster<std::chrono::time_point<std::chrono::system_clock, Duration>> {
|
||||
public:
|
||||
typedef std::chrono::time_point<std::chrono::system_clock, Duration> type;
|
||||
bool load(handle src, bool) {
|
||||
using namespace std::chrono;
|
||||
|
||||
// Lazy initialise the PyDateTime import
|
||||
if (!PyDateTimeAPI) { PyDateTime_IMPORT; }
|
||||
|
||||
if (!src) return false;
|
||||
if (PyDateTime_Check(src.ptr())) {
|
||||
std::tm cal;
|
||||
cal.tm_sec = PyDateTime_DATE_GET_SECOND(src.ptr());
|
||||
cal.tm_min = PyDateTime_DATE_GET_MINUTE(src.ptr());
|
||||
cal.tm_hour = PyDateTime_DATE_GET_HOUR(src.ptr());
|
||||
cal.tm_mday = PyDateTime_GET_DAY(src.ptr());
|
||||
cal.tm_mon = PyDateTime_GET_MONTH(src.ptr()) - 1;
|
||||
cal.tm_year = PyDateTime_GET_YEAR(src.ptr()) - 1900;
|
||||
cal.tm_isdst = -1;
|
||||
|
||||
value = system_clock::from_time_t(std::mktime(&cal)) + microseconds(PyDateTime_DATE_GET_MICROSECOND(src.ptr()));
|
||||
return true;
|
||||
}
|
||||
else return false;
|
||||
}
|
||||
|
||||
static handle cast(const std::chrono::time_point<std::chrono::system_clock, Duration> &src, return_value_policy /* policy */, handle /* parent */) {
|
||||
using namespace std::chrono;
|
||||
|
||||
// Lazy initialise the PyDateTime import
|
||||
if (!PyDateTimeAPI) { PyDateTime_IMPORT; }
|
||||
|
||||
std::time_t tt = system_clock::to_time_t(src);
|
||||
// this function uses static memory so it's best to copy it out asap just in case
|
||||
// otherwise other code that is using localtime may break this (not just python code)
|
||||
std::tm localtime = *std::localtime(&tt);
|
||||
|
||||
// Declare these special duration types so the conversions happen with the correct primitive types (int)
|
||||
using us_t = duration<int, std::micro>;
|
||||
|
||||
return PyDateTime_FromDateAndTime(localtime.tm_year + 1900,
|
||||
localtime.tm_mon + 1,
|
||||
localtime.tm_mday,
|
||||
localtime.tm_hour,
|
||||
localtime.tm_min,
|
||||
localtime.tm_sec,
|
||||
(duration_cast<us_t>(src.time_since_epoch() % seconds(1))).count());
|
||||
}
|
||||
PYBIND11_TYPE_CASTER(type, _("datetime.datetime"));
|
||||
};
|
||||
|
||||
// Other clocks that are not the system clock are not measured as datetime.datetime objects
|
||||
// since they are not measured on calendar time. So instead we just make them timedeltas
|
||||
// Or if they have passed us a time as a float we convert that
|
||||
template <typename Clock, typename Duration> class type_caster<std::chrono::time_point<Clock, Duration>>
|
||||
: public duration_caster<std::chrono::time_point<Clock, Duration>> {
|
||||
};
|
||||
|
||||
template <typename Rep, typename Period> class type_caster<std::chrono::duration<Rep, Period>>
|
||||
: public duration_caster<std::chrono::duration<Rep, Period>> {
|
||||
};
|
||||
|
||||
NAMESPACE_END(detail)
|
||||
NAMESPACE_END(PYBIND11_NAMESPACE)
|
|
@ -1,603 +0,0 @@
|
|||
/*
|
||||
pybind11/class_support.h: Python C API implementation details for py::class_
|
||||
|
||||
Copyright (c) 2017 Wenzel Jakob <wenzel.jakob@epfl.ch>
|
||||
|
||||
All rights reserved. Use of this source code is governed by a
|
||||
BSD-style license that can be found in the LICENSE file.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "attr.h"
|
||||
|
||||
NAMESPACE_BEGIN(pybind11)
|
||||
NAMESPACE_BEGIN(detail)
|
||||
|
||||
inline PyTypeObject *type_incref(PyTypeObject *type) {
|
||||
Py_INCREF(type);
|
||||
return type;
|
||||
}
|
||||
|
||||
#if !defined(PYPY_VERSION)
|
||||
|
||||
/// `pybind11_static_property.__get__()`: Always pass the class instead of the instance.
|
||||
extern "C" inline PyObject *pybind11_static_get(PyObject *self, PyObject * /*ob*/, PyObject *cls) {
|
||||
return PyProperty_Type.tp_descr_get(self, cls, cls);
|
||||
}
|
||||
|
||||
/// `pybind11_static_property.__set__()`: Just like the above `__get__()`.
|
||||
extern "C" inline int pybind11_static_set(PyObject *self, PyObject *obj, PyObject *value) {
|
||||
PyObject *cls = PyType_Check(obj) ? obj : (PyObject *) Py_TYPE(obj);
|
||||
return PyProperty_Type.tp_descr_set(self, cls, value);
|
||||
}
|
||||
|
||||
/** A `static_property` is the same as a `property` but the `__get__()` and `__set__()`
|
||||
methods are modified to always use the object type instead of a concrete instance.
|
||||
Return value: New reference. */
|
||||
inline PyTypeObject *make_static_property_type() {
|
||||
constexpr auto *name = "pybind11_static_property";
|
||||
auto name_obj = reinterpret_steal<object>(PYBIND11_FROM_STRING(name));
|
||||
|
||||
/* Danger zone: from now (and until PyType_Ready), make sure to
|
||||
issue no Python C API calls which could potentially invoke the
|
||||
garbage collector (the GC will call type_traverse(), which will in
|
||||
turn find the newly constructed type in an invalid state) */
|
||||
auto heap_type = (PyHeapTypeObject *) PyType_Type.tp_alloc(&PyType_Type, 0);
|
||||
if (!heap_type)
|
||||
pybind11_fail("make_static_property_type(): error allocating type!");
|
||||
|
||||
heap_type->ht_name = name_obj.inc_ref().ptr();
|
||||
#if PY_MAJOR_VERSION >= 3 && PY_MINOR_VERSION >= 3
|
||||
heap_type->ht_qualname = name_obj.inc_ref().ptr();
|
||||
#endif
|
||||
|
||||
auto type = &heap_type->ht_type;
|
||||
type->tp_name = name;
|
||||
type->tp_base = type_incref(&PyProperty_Type);
|
||||
type->tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HEAPTYPE;
|
||||
type->tp_descr_get = pybind11_static_get;
|
||||
type->tp_descr_set = pybind11_static_set;
|
||||
|
||||
if (PyType_Ready(type) < 0)
|
||||
pybind11_fail("make_static_property_type(): failure in PyType_Ready()!");
|
||||
|
||||
setattr((PyObject *) type, "__module__", str("pybind11_builtins"));
|
||||
|
||||
return type;
|
||||
}
|
||||
|
||||
#else // PYPY
|
||||
|
||||
/** PyPy has some issues with the above C API, so we evaluate Python code instead.
|
||||
This function will only be called once so performance isn't really a concern.
|
||||
Return value: New reference. */
|
||||
inline PyTypeObject *make_static_property_type() {
|
||||
auto d = dict();
|
||||
PyObject *result = PyRun_String(R"(\
|
||||
class pybind11_static_property(property):
|
||||
def __get__(self, obj, cls):
|
||||
return property.__get__(self, cls, cls)
|
||||
|
||||
def __set__(self, obj, value):
|
||||
cls = obj if isinstance(obj, type) else type(obj)
|
||||
property.__set__(self, cls, value)
|
||||
)", Py_file_input, d.ptr(), d.ptr()
|
||||
);
|
||||
if (result == nullptr)
|
||||
throw error_already_set();
|
||||
Py_DECREF(result);
|
||||
return (PyTypeObject *) d["pybind11_static_property"].cast<object>().release().ptr();
|
||||
}
|
||||
|
||||
#endif // PYPY
|
||||
|
||||
/** Types with static properties need to handle `Type.static_prop = x` in a specific way.
    By default, Python replaces the `static_property` itself, but for wrapped C++ types
    we need to call `static_property.__set__()` in order to propagate the new value to
    the underlying C++ data structure. */
extern "C" inline int pybind11_meta_setattro(PyObject* obj, PyObject* name, PyObject* value) {
    // Use `_PyType_Lookup()` instead of `PyObject_GetAttr()` in order to get the raw
    // descriptor (`property`) instead of calling `tp_descr_get` (`property.__get__()`).
    PyObject *descr = _PyType_Lookup((PyTypeObject *) obj, name);

    // The following assignment combinations are possible:
    // 1. `Type.static_prop = value`             --> descr_set: `Type.static_prop.__set__(value)`
    // 2. `Type.static_prop = other_static_prop` --> setattro: replace existing `static_prop`
    // 3. `Type.regular_attribute = value`       --> setattro: regular attribute assignment
    const auto static_prop = (PyObject *) get_internals().static_property_type;
    const auto call_descr_set = descr && PyObject_IsInstance(descr, static_prop)
                                && !PyObject_IsInstance(value, static_prop);
    if (call_descr_set) {
        // Call `static_property.__set__()` instead of replacing the `static_property`.
#if !defined(PYPY_VERSION)
        return Py_TYPE(descr)->tp_descr_set(descr, obj, value);
#else
        if (PyObject *result = PyObject_CallMethod(descr, "__set__", "OO", obj, value)) {
            Py_DECREF(result);
            return 0;
        } else {
            return -1;
        }
#endif
    } else {
        // Replace existing attribute.
        return PyType_Type.tp_setattro(obj, name, value);
    }
}
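// Usage sketch (illustrative, not part of this header): the metaclass machinery above is
// what lets a Python-side assignment to a static property reach the underlying C++ data.
// `Counter`, `example` and `m` are placeholder names; the calls below come from
// pybind11's public API, not from this file.
#include <pybind11/pybind11.h>
namespace py = pybind11;

struct Counter { static int value; };
int Counter::value = 0;

PYBIND11_MODULE(example, m) {
    // def_readwrite_static installs a pybind11_static_property on the class; a Python
    // statement like `Counter.value = 7` is then routed through pybind11_meta_setattro,
    // which calls static_property.__set__() and updates the C++ variable instead of
    // replacing the descriptor.
    py::class_<Counter>(m, "Counter")
        .def_readwrite_static("value", &Counter::value);
}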
|
||||
|
||||
#if PY_MAJOR_VERSION >= 3
|
||||
/**
|
||||
* Python 3's PyInstanceMethod_Type hides itself via its tp_descr_get, which prevents aliasing
|
||||
* methods via cls.attr("m2") = cls.attr("m1"): instead the tp_descr_get returns a plain function,
|
||||
* when called on a class, or a PyMethod, when called on an instance. Override that behaviour here
|
||||
* to do a special case bypass for PyInstanceMethod_Types.
|
||||
*/
|
||||
extern "C" inline PyObject *pybind11_meta_getattro(PyObject *obj, PyObject *name) {
|
||||
PyObject *descr = _PyType_Lookup((PyTypeObject *) obj, name);
|
||||
if (descr && PyInstanceMethod_Check(descr)) {
|
||||
Py_INCREF(descr);
|
||||
return descr;
|
||||
}
|
||||
else {
|
||||
return PyType_Type.tp_getattro(obj, name);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
/** This metaclass is assigned by default to all pybind11 types and is required in order
|
||||
for static properties to function correctly. Users may override this using `py::metaclass`.
|
||||
Return value: New reference. */
|
||||
inline PyTypeObject* make_default_metaclass() {
|
||||
constexpr auto *name = "pybind11_type";
|
||||
auto name_obj = reinterpret_steal<object>(PYBIND11_FROM_STRING(name));
|
||||
|
||||
/* Danger zone: from now (and until PyType_Ready), make sure to
|
||||
issue no Python C API calls which could potentially invoke the
|
||||
garbage collector (the GC will call type_traverse(), which will in
|
||||
turn find the newly constructed type in an invalid state) */
|
||||
auto heap_type = (PyHeapTypeObject *) PyType_Type.tp_alloc(&PyType_Type, 0);
|
||||
if (!heap_type)
|
||||
pybind11_fail("make_default_metaclass(): error allocating metaclass!");
|
||||
|
||||
heap_type->ht_name = name_obj.inc_ref().ptr();
|
||||
#if PY_MAJOR_VERSION >= 3 && PY_MINOR_VERSION >= 3
|
||||
heap_type->ht_qualname = name_obj.inc_ref().ptr();
|
||||
#endif
|
||||
|
||||
auto type = &heap_type->ht_type;
|
||||
type->tp_name = name;
|
||||
type->tp_base = type_incref(&PyType_Type);
|
||||
type->tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HEAPTYPE;
|
||||
|
||||
type->tp_setattro = pybind11_meta_setattro;
|
||||
#if PY_MAJOR_VERSION >= 3
|
||||
type->tp_getattro = pybind11_meta_getattro;
|
||||
#endif
|
||||
|
||||
if (PyType_Ready(type) < 0)
|
||||
pybind11_fail("make_default_metaclass(): failure in PyType_Ready()!");
|
||||
|
||||
setattr((PyObject *) type, "__module__", str("pybind11_builtins"));
|
||||
|
||||
return type;
|
||||
}
|
||||
|
||||
/// For multiple inheritance types we need to recursively register/deregister base pointers for any
/// base classes with pointers that are different from the instance value pointer so that we can
/// correctly recognize an offset base class pointer. This calls a function with any offset base ptrs.
inline void traverse_offset_bases(void *valueptr, const detail::type_info *tinfo, instance *self,
                                  bool (*f)(void * /*parentptr*/, instance * /*self*/)) {
    for (handle h : reinterpret_borrow<tuple>(tinfo->type->tp_bases)) {
        if (auto parent_tinfo = get_type_info((PyTypeObject *) h.ptr())) {
            for (auto &c : parent_tinfo->implicit_casts) {
                if (c.first == tinfo->cpptype) {
                    auto *parentptr = c.second(valueptr);
                    if (parentptr != valueptr)
                        f(parentptr, self);
                    traverse_offset_bases(parentptr, parent_tinfo, self, f);
                    break;
                }
            }
        }
    }
}
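// Usage sketch (illustrative): a multiple-inheritance hierarchy in which the Base2
// subobject lives at a non-zero offset inside Derived -- exactly the situation that
// traverse_offset_bases() walks when registering and deregistering instances.
// `Base1`, `Base2`, `Derived` and `mi_example` are placeholder names.
#include <pybind11/pybind11.h>
namespace py = pybind11;

struct Base1 { int a = 1; };
struct Base2 { int b = 2; };
struct Derived : Base1, Base2 { };   // (Base2 *) d != (void *) d for a Derived *d

PYBIND11_MODULE(mi_example, m) {
    py::class_<Base1>(m, "Base1");
    py::class_<Base2>(m, "Base2");
    // With two registered bases the type no longer has "simple ancestors", so offset
    // base pointers get registered through traverse_offset_bases().
    py::class_<Derived, Base1, Base2>(m, "Derived")
        .def(py::init<>());
}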
|
||||
|
||||
inline bool register_instance_impl(void *ptr, instance *self) {
|
||||
get_internals().registered_instances.emplace(ptr, self);
|
||||
return true; // unused, but gives the same signature as the deregister func
|
||||
}
|
||||
inline bool deregister_instance_impl(void *ptr, instance *self) {
|
||||
auto ®istered_instances = get_internals().registered_instances;
|
||||
auto range = registered_instances.equal_range(ptr);
|
||||
for (auto it = range.first; it != range.second; ++it) {
|
||||
if (Py_TYPE(self) == Py_TYPE(it->second)) {
|
||||
registered_instances.erase(it);
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
inline void register_instance(instance *self, void *valptr, const type_info *tinfo) {
|
||||
register_instance_impl(valptr, self);
|
||||
if (!tinfo->simple_ancestors)
|
||||
traverse_offset_bases(valptr, tinfo, self, register_instance_impl);
|
||||
}
|
||||
|
||||
inline bool deregister_instance(instance *self, void *valptr, const type_info *tinfo) {
|
||||
bool ret = deregister_instance_impl(valptr, self);
|
||||
if (!tinfo->simple_ancestors)
|
||||
traverse_offset_bases(valptr, tinfo, self, deregister_instance_impl);
|
||||
return ret;
|
||||
}
|
||||
|
||||
/// Instance creation function for all pybind11 types. It only allocates space for the C++ object
|
||||
/// (or multiple objects, for Python-side inheritance from multiple pybind11 types), but doesn't
|
||||
/// call the constructor -- an `__init__` function must do that (followed by an `init_instance`
|
||||
/// to set up the holder and register the instance).
|
||||
inline PyObject *make_new_instance(PyTypeObject *type, bool allocate_value /*= true (in cast.h)*/) {
|
||||
#if defined(PYPY_VERSION)
|
||||
// PyPy gets tp_basicsize wrong (issue 2482) under multiple inheritance when the first inherited
|
||||
// object is a plain Python type (i.e. not derived from an extension type). Fix it.
|
||||
ssize_t instance_size = static_cast<ssize_t>(sizeof(instance));
|
||||
if (type->tp_basicsize < instance_size) {
|
||||
type->tp_basicsize = instance_size;
|
||||
}
|
||||
#endif
|
||||
PyObject *self = type->tp_alloc(type, 0);
|
||||
auto inst = reinterpret_cast<instance *>(self);
|
||||
// Allocate the value/holder internals:
|
||||
inst->allocate_layout();
|
||||
|
||||
inst->owned = true;
|
||||
// Allocate (if requested) the value pointers; otherwise leave them as nullptr
|
||||
if (allocate_value) {
|
||||
for (auto &v_h : values_and_holders(inst)) {
|
||||
void *&vptr = v_h.value_ptr();
|
||||
vptr = v_h.type->operator_new(v_h.type->type_size);
|
||||
}
|
||||
}
|
||||
|
||||
return self;
|
||||
}
|
||||
|
||||
/// Instance creation function for all pybind11 types. It only allocates space for the
|
||||
/// C++ object, but doesn't call the constructor -- an `__init__` function must do that.
|
||||
extern "C" inline PyObject *pybind11_object_new(PyTypeObject *type, PyObject *, PyObject *) {
|
||||
return make_new_instance(type);
|
||||
}
|
||||
|
||||
/// An `__init__` function constructs the C++ object. Users should provide at least one
/// of these using `py::init` or directly with `.def(__init__, ...)`. Otherwise, the
/// following default function will be used which simply throws an exception.
extern "C" inline int pybind11_object_init(PyObject *self, PyObject *, PyObject *) {
    PyTypeObject *type = Py_TYPE(self);
    std::string msg;
#if defined(PYPY_VERSION)
    msg += handle((PyObject *) type).attr("__module__").cast<std::string>() + ".";
#endif
    msg += type->tp_name;
    msg += ": No constructor defined!";
    PyErr_SetString(PyExc_TypeError, msg.c_str());
    return -1;
}
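// Usage sketch (illustrative): providing a real __init__ via py::init<> so that the
// default pybind11_object_init above is never reached. `Widget` and `init_example`
// are placeholder names.
#include <pybind11/pybind11.h>
namespace py = pybind11;

struct Widget { explicit Widget(int size) : size(size) {} int size; };

PYBIND11_MODULE(init_example, m) {
    py::class_<Widget>(m, "Widget")
        .def(py::init<int>())                 // constructs the C++ object
        .def_readonly("size", &Widget::size);
    // A class bound without any py::init would raise TypeError
    // ("...: No constructor defined!") when instantiated from Python.
}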
|
||||
|
||||
inline void add_patient(PyObject *nurse, PyObject *patient) {
    auto &internals = get_internals();
    auto instance = reinterpret_cast<detail::instance *>(nurse);
    instance->has_patients = true;
    Py_INCREF(patient);
    internals.patients[nurse].push_back(patient);
}
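// Usage sketch (illustrative): add_patient() is the mechanism behind py::keep_alive.
// Here the pointer returned by get() must not outlive its Container, so the call
// policy makes the Container a "patient" of the returned object. `Item`, `Container`
// and `keepalive_example` are placeholder names.
#include <pybind11/pybind11.h>
namespace py = pybind11;

struct Item { int id = 0; };
struct Container {
    Item item;
    Item *get() { return &item; }
};

PYBIND11_MODULE(keepalive_example, m) {
    py::class_<Item>(m, "Item");
    py::class_<Container>(m, "Container")
        .def(py::init<>())
        // keep_alive<0, 1>: keep argument 1 (the Container, i.e. self) alive for as
        // long as the return value (index 0) exists.
        .def("get", &Container::get,
             py::return_value_policy::reference, py::keep_alive<0, 1>());
}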
|
||||
|
||||
inline void clear_patients(PyObject *self) {
|
||||
auto instance = reinterpret_cast<detail::instance *>(self);
|
||||
auto &internals = get_internals();
|
||||
auto pos = internals.patients.find(self);
|
||||
assert(pos != internals.patients.end());
|
||||
// Clearing the patients can cause more Python code to run, which
|
||||
// can invalidate the iterator. Extract the vector of patients
|
||||
// from the unordered_map first.
|
||||
auto patients = std::move(pos->second);
|
||||
internals.patients.erase(pos);
|
||||
instance->has_patients = false;
|
||||
for (PyObject *&patient : patients)
|
||||
Py_CLEAR(patient);
|
||||
}
|
||||
|
||||
/// Clears all internal data from the instance and removes it from registered instances in
|
||||
/// preparation for deallocation.
|
||||
inline void clear_instance(PyObject *self) {
|
||||
auto instance = reinterpret_cast<detail::instance *>(self);
|
||||
|
||||
// Deallocate any values/holders, if present:
|
||||
for (auto &v_h : values_and_holders(instance)) {
|
||||
if (v_h) {
|
||||
|
||||
// We have to deregister before we call dealloc because, for virtual MI types, we still
|
||||
// need to be able to get the parent pointers.
|
||||
if (v_h.instance_registered() && !deregister_instance(instance, v_h.value_ptr(), v_h.type))
|
||||
pybind11_fail("pybind11_object_dealloc(): Tried to deallocate unregistered instance!");
|
||||
|
||||
if (instance->owned || v_h.holder_constructed())
|
||||
v_h.type->dealloc(v_h);
|
||||
}
|
||||
}
|
||||
// Deallocate the value/holder layout internals:
|
||||
instance->deallocate_layout();
|
||||
|
||||
if (instance->weakrefs)
|
||||
PyObject_ClearWeakRefs(self);
|
||||
|
||||
PyObject **dict_ptr = _PyObject_GetDictPtr(self);
|
||||
if (dict_ptr)
|
||||
Py_CLEAR(*dict_ptr);
|
||||
|
||||
if (instance->has_patients)
|
||||
clear_patients(self);
|
||||
}
|
||||
|
||||
/// Instance destructor function for all pybind11 types. It calls `type_info.dealloc`
|
||||
/// to destroy the C++ object itself, while the rest is Python bookkeeping.
|
||||
extern "C" inline void pybind11_object_dealloc(PyObject *self) {
|
||||
clear_instance(self);
|
||||
Py_TYPE(self)->tp_free(self);
|
||||
}
|
||||
|
||||
/** Create the type which can be used as a common base for all classes. This is
|
||||
needed in order to satisfy Python's requirements for multiple inheritance.
|
||||
Return value: New reference. */
|
||||
inline PyObject *make_object_base_type(PyTypeObject *metaclass) {
|
||||
constexpr auto *name = "pybind11_object";
|
||||
auto name_obj = reinterpret_steal<object>(PYBIND11_FROM_STRING(name));
|
||||
|
||||
/* Danger zone: from now (and until PyType_Ready), make sure to
|
||||
issue no Python C API calls which could potentially invoke the
|
||||
garbage collector (the GC will call type_traverse(), which will in
|
||||
turn find the newly constructed type in an invalid state) */
|
||||
auto heap_type = (PyHeapTypeObject *) metaclass->tp_alloc(metaclass, 0);
|
||||
if (!heap_type)
|
||||
pybind11_fail("make_object_base_type(): error allocating type!");
|
||||
|
||||
heap_type->ht_name = name_obj.inc_ref().ptr();
|
||||
#if PY_MAJOR_VERSION >= 3 && PY_MINOR_VERSION >= 3
|
||||
heap_type->ht_qualname = name_obj.inc_ref().ptr();
|
||||
#endif
|
||||
|
||||
auto type = &heap_type->ht_type;
|
||||
type->tp_name = name;
|
||||
type->tp_base = type_incref(&PyBaseObject_Type);
|
||||
type->tp_basicsize = static_cast<ssize_t>(sizeof(instance));
|
||||
type->tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HEAPTYPE;
|
||||
|
||||
type->tp_new = pybind11_object_new;
|
||||
type->tp_init = pybind11_object_init;
|
||||
type->tp_dealloc = pybind11_object_dealloc;
|
||||
|
||||
/* Support weak references (needed for the keep_alive feature) */
|
||||
type->tp_weaklistoffset = offsetof(instance, weakrefs);
|
||||
|
||||
if (PyType_Ready(type) < 0)
|
||||
pybind11_fail("PyType_Ready failed in make_object_base_type():" + error_string());
|
||||
|
||||
setattr((PyObject *) type, "__module__", str("pybind11_builtins"));
|
||||
|
||||
assert(!PyType_HasFeature(type, Py_TPFLAGS_HAVE_GC));
|
||||
return (PyObject *) heap_type;
|
||||
}
|
||||
|
||||
/// dynamic_attr: Support for `d = instance.__dict__`.
|
||||
extern "C" inline PyObject *pybind11_get_dict(PyObject *self, void *) {
|
||||
PyObject *&dict = *_PyObject_GetDictPtr(self);
|
||||
if (!dict)
|
||||
dict = PyDict_New();
|
||||
Py_XINCREF(dict);
|
||||
return dict;
|
||||
}
|
||||
|
||||
/// dynamic_attr: Support for `instance.__dict__ = dict()`.
|
||||
extern "C" inline int pybind11_set_dict(PyObject *self, PyObject *new_dict, void *) {
|
||||
if (!PyDict_Check(new_dict)) {
|
||||
PyErr_Format(PyExc_TypeError, "__dict__ must be set to a dictionary, not a '%.200s'",
|
||||
Py_TYPE(new_dict)->tp_name);
|
||||
return -1;
|
||||
}
|
||||
PyObject *&dict = *_PyObject_GetDictPtr(self);
|
||||
Py_INCREF(new_dict);
|
||||
Py_CLEAR(dict);
|
||||
dict = new_dict;
|
||||
return 0;
|
||||
}
|
||||
|
||||
/// dynamic_attr: Allow the garbage collector to traverse the internal instance `__dict__`.
|
||||
extern "C" inline int pybind11_traverse(PyObject *self, visitproc visit, void *arg) {
|
||||
PyObject *&dict = *_PyObject_GetDictPtr(self);
|
||||
Py_VISIT(dict);
|
||||
return 0;
|
||||
}
|
||||
|
||||
/// dynamic_attr: Allow the GC to clear the dictionary.
|
||||
extern "C" inline int pybind11_clear(PyObject *self) {
|
||||
PyObject *&dict = *_PyObject_GetDictPtr(self);
|
||||
Py_CLEAR(dict);
|
||||
return 0;
|
||||
}
|
||||
|
||||
/// Give instances of this type a `__dict__` and opt into garbage collection.
inline void enable_dynamic_attributes(PyHeapTypeObject *heap_type) {
    auto type = &heap_type->ht_type;
#if defined(PYPY_VERSION)
    pybind11_fail(std::string(type->tp_name) + ": dynamic attributes are "
                                               "currently not supported in "
                                               "conjunction with PyPy!");
#endif
    type->tp_flags |= Py_TPFLAGS_HAVE_GC;
    type->tp_dictoffset = type->tp_basicsize;           // place dict at the end
    type->tp_basicsize += (ssize_t)sizeof(PyObject *);  // and allocate enough space for it
    type->tp_traverse = pybind11_traverse;
    type->tp_clear = pybind11_clear;

    static PyGetSetDef getset[] = {
        {const_cast<char*>("__dict__"), pybind11_get_dict, pybind11_set_dict, nullptr, nullptr},
        {nullptr, nullptr, nullptr, nullptr, nullptr}
    };
    type->tp_getset = getset;
}
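// Usage sketch (illustrative): py::dynamic_attr() is the user-facing switch that ends
// up calling enable_dynamic_attributes(), giving each instance a __dict__ and GC
// support. `Pet` and `dynattr_example` are placeholder names.
#include <pybind11/pybind11.h>
#include <string>
namespace py = pybind11;

struct Pet { std::string name; };

PYBIND11_MODULE(dynattr_example, m) {
    py::class_<Pet>(m, "Pet", py::dynamic_attr())
        .def(py::init<>())
        .def_readwrite("name", &Pet::name);
    // Python:  p = Pet(); p.age = 2   # allowed, because instances now carry a __dict__
}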
|
||||
|
||||
/// buffer_protocol: Fill in the view as specified by flags.
|
||||
extern "C" inline int pybind11_getbuffer(PyObject *obj, Py_buffer *view, int flags) {
|
||||
// Look for a `get_buffer` implementation in this type's info or any bases (following MRO).
|
||||
type_info *tinfo = nullptr;
|
||||
for (auto type : reinterpret_borrow<tuple>(Py_TYPE(obj)->tp_mro)) {
|
||||
tinfo = get_type_info((PyTypeObject *) type.ptr());
|
||||
if (tinfo && tinfo->get_buffer)
|
||||
break;
|
||||
}
|
||||
if (view == nullptr || obj == nullptr || !tinfo || !tinfo->get_buffer) {
|
||||
if (view)
|
||||
view->obj = nullptr;
|
||||
PyErr_SetString(PyExc_BufferError, "pybind11_getbuffer(): Internal error");
|
||||
return -1;
|
||||
}
|
||||
std::memset(view, 0, sizeof(Py_buffer));
|
||||
buffer_info *info = tinfo->get_buffer(obj, tinfo->get_buffer_data);
|
||||
view->obj = obj;
|
||||
view->ndim = 1;
|
||||
view->internal = info;
|
||||
view->buf = info->ptr;
|
||||
view->itemsize = info->itemsize;
|
||||
view->len = view->itemsize;
|
||||
for (auto s : info->shape)
|
||||
view->len *= s;
|
||||
if ((flags & PyBUF_FORMAT) == PyBUF_FORMAT)
|
||||
view->format = const_cast<char *>(info->format.c_str());
|
||||
if ((flags & PyBUF_STRIDES) == PyBUF_STRIDES) {
|
||||
view->ndim = (int) info->ndim;
|
||||
view->strides = &info->strides[0];
|
||||
view->shape = &info->shape[0];
|
||||
}
|
||||
Py_INCREF(view->obj);
|
||||
return 0;
|
||||
}
|
||||
|
||||
/// buffer_protocol: Release the resources of the buffer.
|
||||
extern "C" inline void pybind11_releasebuffer(PyObject *, Py_buffer *view) {
|
||||
delete (buffer_info *) view->internal;
|
||||
}
|
||||
|
||||
/// Give this type a buffer interface.
inline void enable_buffer_protocol(PyHeapTypeObject *heap_type) {
    heap_type->ht_type.tp_as_buffer = &heap_type->as_buffer;
#if PY_MAJOR_VERSION < 3
    heap_type->ht_type.tp_flags |= Py_TPFLAGS_HAVE_NEWBUFFER;
#endif

    heap_type->as_buffer.bf_getbuffer = pybind11_getbuffer;
    heap_type->as_buffer.bf_releasebuffer = pybind11_releasebuffer;
}
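// Usage sketch (illustrative): py::buffer_protocol() triggers enable_buffer_protocol(),
// and def_buffer() supplies the get_buffer callback consulted by pybind11_getbuffer()
// above. `FloatBuffer` and `buffer_example` are placeholder names.
#include <pybind11/pybind11.h>
#include <vector>
namespace py = pybind11;

struct FloatBuffer {
    std::vector<float> data = std::vector<float>(16, 0.f);
};

PYBIND11_MODULE(buffer_example, m) {
    py::class_<FloatBuffer>(m, "FloatBuffer", py::buffer_protocol())
        .def(py::init<>())
        .def_buffer([](FloatBuffer &b) -> py::buffer_info {
            return py::buffer_info(
                b.data.data(),                               // pointer to the storage
                sizeof(float),                               // size of one scalar
                py::format_descriptor<float>::format(),      // struct-style format ("f")
                1,                                           // number of dimensions
                { (Py_ssize_t) b.data.size() },              // shape
                { (Py_ssize_t) sizeof(float) });             // strides, in bytes
        });
    // Python:  import numpy as np; np.asarray(FloatBuffer())  # zero-copy 1-D view
}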
|
||||
|
||||
/** Create a brand new Python type according to the `type_record` specification.
|
||||
Return value: New reference. */
|
||||
inline PyObject* make_new_python_type(const type_record &rec) {
|
||||
auto name = reinterpret_steal<object>(PYBIND11_FROM_STRING(rec.name));
|
||||
|
||||
#if PY_MAJOR_VERSION >= 3 && PY_MINOR_VERSION >= 3
|
||||
auto ht_qualname = name;
|
||||
if (rec.scope && hasattr(rec.scope, "__qualname__")) {
|
||||
ht_qualname = reinterpret_steal<object>(
|
||||
PyUnicode_FromFormat("%U.%U", rec.scope.attr("__qualname__").ptr(), name.ptr()));
|
||||
}
|
||||
#endif
|
||||
|
||||
object module;
|
||||
if (rec.scope) {
|
||||
if (hasattr(rec.scope, "__module__"))
|
||||
module = rec.scope.attr("__module__");
|
||||
else if (hasattr(rec.scope, "__name__"))
|
||||
module = rec.scope.attr("__name__");
|
||||
}
|
||||
|
||||
#if !defined(PYPY_VERSION)
|
||||
const auto full_name = module ? str(module).cast<std::string>() + "." + rec.name
|
||||
: std::string(rec.name);
|
||||
#else
|
||||
const auto full_name = std::string(rec.name);
|
||||
#endif
|
||||
|
||||
char *tp_doc = nullptr;
|
||||
if (rec.doc && options::show_user_defined_docstrings()) {
|
||||
/* Allocate memory for docstring (using PyObject_MALLOC, since
|
||||
Python will free this later on) */
|
||||
size_t size = strlen(rec.doc) + 1;
|
||||
tp_doc = (char *) PyObject_MALLOC(size);
|
||||
memcpy((void *) tp_doc, rec.doc, size);
|
||||
}
|
||||
|
||||
auto &internals = get_internals();
|
||||
auto bases = tuple(rec.bases);
|
||||
auto base = (bases.size() == 0) ? internals.instance_base
|
||||
: bases[0].ptr();
|
||||
|
||||
/* Danger zone: from now (and until PyType_Ready), make sure to
|
||||
issue no Python C API calls which could potentially invoke the
|
||||
garbage collector (the GC will call type_traverse(), which will in
|
||||
turn find the newly constructed type in an invalid state) */
|
||||
auto metaclass = rec.metaclass.ptr() ? (PyTypeObject *) rec.metaclass.ptr()
|
||||
: internals.default_metaclass;
|
||||
|
||||
auto heap_type = (PyHeapTypeObject *) metaclass->tp_alloc(metaclass, 0);
|
||||
if (!heap_type)
|
||||
pybind11_fail(std::string(rec.name) + ": Unable to create type object!");
|
||||
|
||||
heap_type->ht_name = name.release().ptr();
|
||||
#if PY_MAJOR_VERSION >= 3 && PY_MINOR_VERSION >= 3
|
||||
heap_type->ht_qualname = ht_qualname.release().ptr();
|
||||
#endif
|
||||
|
||||
auto type = &heap_type->ht_type;
|
||||
type->tp_name = strdup(full_name.c_str());
|
||||
type->tp_doc = tp_doc;
|
||||
type->tp_base = type_incref((PyTypeObject *)base);
|
||||
type->tp_basicsize = static_cast<ssize_t>(sizeof(instance));
|
||||
if (bases.size() > 0)
|
||||
type->tp_bases = bases.release().ptr();
|
||||
|
||||
/* Don't inherit base __init__ */
|
||||
type->tp_init = pybind11_object_init;
|
||||
|
||||
/* Supported protocols */
|
||||
type->tp_as_number = &heap_type->as_number;
|
||||
type->tp_as_sequence = &heap_type->as_sequence;
|
||||
type->tp_as_mapping = &heap_type->as_mapping;
|
||||
|
||||
/* Flags */
|
||||
type->tp_flags |= Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HEAPTYPE;
|
||||
#if PY_MAJOR_VERSION < 3
|
||||
type->tp_flags |= Py_TPFLAGS_CHECKTYPES;
|
||||
#endif
|
||||
|
||||
if (rec.dynamic_attr)
|
||||
enable_dynamic_attributes(heap_type);
|
||||
|
||||
if (rec.buffer_protocol)
|
||||
enable_buffer_protocol(heap_type);
|
||||
|
||||
if (PyType_Ready(type) < 0)
|
||||
pybind11_fail(std::string(rec.name) + ": PyType_Ready failed (" + error_string() + ")!");
|
||||
|
||||
assert(rec.dynamic_attr ? PyType_HasFeature(type, Py_TPFLAGS_HAVE_GC)
|
||||
: !PyType_HasFeature(type, Py_TPFLAGS_HAVE_GC));
|
||||
|
||||
/* Register type with the parent scope */
|
||||
if (rec.scope)
|
||||
setattr(rec.scope, rec.name, (PyObject *) type);
|
||||
|
||||
if (module) // Needed by pydoc
|
||||
setattr((PyObject *) type, "__module__", module);
|
||||
|
||||
return (PyObject *) type;
|
||||
}
|
||||
|
||||
NAMESPACE_END(detail)
|
||||
NAMESPACE_END(pybind11)
@ -1,2 +0,0 @@
#include "detail/common.h"
|
||||
#warning "Including 'common.h' is deprecated. It will be removed in v3.0. Use 'pybind11.h'."
@ -1,65 +0,0 @@
/*
|
||||
pybind11/complex.h: Complex number support
|
||||
|
||||
Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>
|
||||
|
||||
All rights reserved. Use of this source code is governed by a
|
||||
BSD-style license that can be found in the LICENSE file.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "pybind11.h"
|
||||
#include <complex>
|
||||
|
||||
/// glibc defines I as a macro which breaks things, e.g., boost template names
|
||||
#ifdef I
|
||||
# undef I
|
||||
#endif
|
||||
|
||||
NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
|
||||
|
||||
template <typename T> struct format_descriptor<std::complex<T>, detail::enable_if_t<std::is_floating_point<T>::value>> {
|
||||
static constexpr const char c = format_descriptor<T>::c;
|
||||
static constexpr const char value[3] = { 'Z', c, '\0' };
|
||||
static std::string format() { return std::string(value); }
|
||||
};
|
||||
|
||||
#ifndef PYBIND11_CPP17
|
||||
|
||||
template <typename T> constexpr const char format_descriptor<
|
||||
std::complex<T>, detail::enable_if_t<std::is_floating_point<T>::value>>::value[3];
|
||||
|
||||
#endif
|
||||
|
||||
NAMESPACE_BEGIN(detail)
|
||||
|
||||
template <typename T> struct is_fmt_numeric<std::complex<T>, detail::enable_if_t<std::is_floating_point<T>::value>> {
    static constexpr bool value = true;
    static constexpr int index = is_fmt_numeric<T>::index + 3;
};

template <typename T> class type_caster<std::complex<T>> {
public:
    bool load(handle src, bool convert) {
        if (!src)
            return false;
        if (!convert && !PyComplex_Check(src.ptr()))
            return false;
        Py_complex result = PyComplex_AsCComplex(src.ptr());
        if (result.real == -1.0 && PyErr_Occurred()) {
            PyErr_Clear();
            return false;
        }
        value = std::complex<T>((T) result.real, (T) result.imag);
        return true;
    }

    static handle cast(const std::complex<T> &src, return_value_policy /* policy */, handle /* parent */) {
        return PyComplex_FromDoubles((double) src.real(), (double) src.imag());
    }

    PYBIND11_TYPE_CASTER(std::complex<T>, _("complex"));
};
|
||||
NAMESPACE_END(detail)
|
||||
NAMESPACE_END(PYBIND11_NAMESPACE)
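// Usage sketch (illustrative): with this caster available (by including
// pybind11/complex.h), std::complex converts to and from Python's built-in complex
// type. `complex_example` is a placeholder module name.
#include <pybind11/pybind11.h>
#include <pybind11/complex.h>
#include <complex>
namespace py = pybind11;

PYBIND11_MODULE(complex_example, m) {
    m.def("conjugate", [](std::complex<double> z) { return std::conj(z); });
    // Python:  conjugate(1 + 2j)  ->  (1-2j)
}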
@ -1,185 +0,0 @@
/*
|
||||
pybind11/descr.h: Helper type for concatenating type signatures
|
||||
either at runtime (C++11) or compile time (C++14)
|
||||
|
||||
Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>
|
||||
|
||||
All rights reserved. Use of this source code is governed by a
|
||||
BSD-style license that can be found in the LICENSE file.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "common.h"
|
||||
|
||||
NAMESPACE_BEGIN(pybind11)
|
||||
NAMESPACE_BEGIN(detail)
|
||||
|
||||
/* Concatenate type signatures at compile time using C++14 */
|
||||
#if defined(PYBIND11_CPP14) && !defined(_MSC_VER)
|
||||
#define PYBIND11_CONSTEXPR_DESCR
|
||||
|
||||
template <size_t Size1, size_t Size2> class descr {
|
||||
template <size_t Size1_, size_t Size2_> friend class descr;
|
||||
public:
|
||||
constexpr descr(char const (&text) [Size1+1], const std::type_info * const (&types)[Size2+1])
|
||||
: descr(text, types,
|
||||
make_index_sequence<Size1>(),
|
||||
make_index_sequence<Size2>()) { }
|
||||
|
||||
constexpr const char *text() const { return m_text; }
|
||||
constexpr const std::type_info * const * types() const { return m_types; }
|
||||
|
||||
template <size_t OtherSize1, size_t OtherSize2>
|
||||
constexpr descr<Size1 + OtherSize1, Size2 + OtherSize2> operator+(const descr<OtherSize1, OtherSize2> &other) const {
|
||||
return concat(other,
|
||||
make_index_sequence<Size1>(),
|
||||
make_index_sequence<Size2>(),
|
||||
make_index_sequence<OtherSize1>(),
|
||||
make_index_sequence<OtherSize2>());
|
||||
}
|
||||
|
||||
protected:
|
||||
template <size_t... Indices1, size_t... Indices2>
|
||||
constexpr descr(
|
||||
char const (&text) [Size1+1],
|
||||
const std::type_info * const (&types) [Size2+1],
|
||||
index_sequence<Indices1...>, index_sequence<Indices2...>)
|
||||
: m_text{text[Indices1]..., '\0'},
|
||||
m_types{types[Indices2]..., nullptr } {}
|
||||
|
||||
template <size_t OtherSize1, size_t OtherSize2, size_t... Indices1,
|
||||
size_t... Indices2, size_t... OtherIndices1, size_t... OtherIndices2>
|
||||
constexpr descr<Size1 + OtherSize1, Size2 + OtherSize2>
|
||||
concat(const descr<OtherSize1, OtherSize2> &other,
|
||||
index_sequence<Indices1...>, index_sequence<Indices2...>,
|
||||
index_sequence<OtherIndices1...>, index_sequence<OtherIndices2...>) const {
|
||||
return descr<Size1 + OtherSize1, Size2 + OtherSize2>(
|
||||
{ m_text[Indices1]..., other.m_text[OtherIndices1]..., '\0' },
|
||||
{ m_types[Indices2]..., other.m_types[OtherIndices2]..., nullptr }
|
||||
);
|
||||
}
|
||||
|
||||
protected:
|
||||
char m_text[Size1 + 1];
|
||||
const std::type_info * m_types[Size2 + 1];
|
||||
};
|
||||
|
||||
template <size_t Size> constexpr descr<Size - 1, 0> _(char const(&text)[Size]) {
|
||||
return descr<Size - 1, 0>(text, { nullptr });
|
||||
}
|
||||
|
||||
template <size_t Rem, size_t... Digits> struct int_to_str : int_to_str<Rem/10, Rem%10, Digits...> { };
|
||||
template <size_t...Digits> struct int_to_str<0, Digits...> {
|
||||
static constexpr auto digits = descr<sizeof...(Digits), 0>({ ('0' + Digits)..., '\0' }, { nullptr });
|
||||
};
|
||||
|
||||
// Ternary description (like std::conditional)
|
||||
template <bool B, size_t Size1, size_t Size2>
|
||||
constexpr enable_if_t<B, descr<Size1 - 1, 0>> _(char const(&text1)[Size1], char const(&)[Size2]) {
|
||||
return _(text1);
|
||||
}
|
||||
template <bool B, size_t Size1, size_t Size2>
|
||||
constexpr enable_if_t<!B, descr<Size2 - 1, 0>> _(char const(&)[Size1], char const(&text2)[Size2]) {
|
||||
return _(text2);
|
||||
}
|
||||
template <bool B, size_t SizeA1, size_t SizeA2, size_t SizeB1, size_t SizeB2>
|
||||
constexpr enable_if_t<B, descr<SizeA1, SizeA2>> _(descr<SizeA1, SizeA2> d, descr<SizeB1, SizeB2>) { return d; }
|
||||
template <bool B, size_t SizeA1, size_t SizeA2, size_t SizeB1, size_t SizeB2>
|
||||
constexpr enable_if_t<!B, descr<SizeB1, SizeB2>> _(descr<SizeA1, SizeA2>, descr<SizeB1, SizeB2> d) { return d; }
|
||||
|
||||
template <size_t Size> auto constexpr _() -> decltype(int_to_str<Size / 10, Size % 10>::digits) {
|
||||
return int_to_str<Size / 10, Size % 10>::digits;
|
||||
}
|
||||
|
||||
template <typename Type> constexpr descr<1, 1> _() {
|
||||
return descr<1, 1>({ '%', '\0' }, { &typeid(Type), nullptr });
|
||||
}
|
||||
|
||||
inline constexpr descr<0, 0> concat() { return _(""); }
|
||||
template <size_t Size1, size_t Size2, typename... Args> auto constexpr concat(descr<Size1, Size2> descr) { return descr; }
|
||||
template <size_t Size1, size_t Size2, typename... Args> auto constexpr concat(descr<Size1, Size2> descr, Args&&... args) { return descr + _(", ") + concat(args...); }
|
||||
template <size_t Size1, size_t Size2> auto constexpr type_descr(descr<Size1, Size2> descr) { return _("{") + descr + _("}"); }
|
||||
|
||||
#define PYBIND11_DESCR constexpr auto
|
||||
|
||||
#else /* Simpler C++11 implementation based on run-time memory allocation and copying */
|
||||
|
||||
class descr {
|
||||
public:
|
||||
PYBIND11_NOINLINE descr(const char *text, const std::type_info * const * types) {
|
||||
size_t nChars = len(text), nTypes = len(types);
|
||||
m_text = new char[nChars];
|
||||
m_types = new const std::type_info *[nTypes];
|
||||
memcpy(m_text, text, nChars * sizeof(char));
|
||||
memcpy(m_types, types, nTypes * sizeof(const std::type_info *));
|
||||
}
|
||||
|
||||
PYBIND11_NOINLINE descr operator+(descr &&d2) && {
|
||||
descr r;
|
||||
|
||||
size_t nChars1 = len(m_text), nTypes1 = len(m_types);
|
||||
size_t nChars2 = len(d2.m_text), nTypes2 = len(d2.m_types);
|
||||
|
||||
r.m_text = new char[nChars1 + nChars2 - 1];
|
||||
r.m_types = new const std::type_info *[nTypes1 + nTypes2 - 1];
|
||||
memcpy(r.m_text, m_text, (nChars1-1) * sizeof(char));
|
||||
memcpy(r.m_text + nChars1 - 1, d2.m_text, nChars2 * sizeof(char));
|
||||
memcpy(r.m_types, m_types, (nTypes1-1) * sizeof(std::type_info *));
|
||||
memcpy(r.m_types + nTypes1 - 1, d2.m_types, nTypes2 * sizeof(std::type_info *));
|
||||
|
||||
delete[] m_text; delete[] m_types;
|
||||
delete[] d2.m_text; delete[] d2.m_types;
|
||||
|
||||
return r;
|
||||
}
|
||||
|
||||
char *text() { return m_text; }
|
||||
const std::type_info * * types() { return m_types; }
|
||||
|
||||
protected:
|
||||
PYBIND11_NOINLINE descr() { }
|
||||
|
||||
template <typename T> static size_t len(const T *ptr) { // return length including null termination
|
||||
const T *it = ptr;
|
||||
while (*it++ != (T) 0)
|
||||
;
|
||||
return static_cast<size_t>(it - ptr);
|
||||
}
|
||||
|
||||
const std::type_info **m_types = nullptr;
|
||||
char *m_text = nullptr;
|
||||
};
|
||||
|
||||
/* The 'PYBIND11_NOINLINE inline' combinations below are intentional to get the desired linkage while producing as little object code as possible */
|
||||
|
||||
PYBIND11_NOINLINE inline descr _(const char *text) {
|
||||
const std::type_info *types[1] = { nullptr };
|
||||
return descr(text, types);
|
||||
}
|
||||
|
||||
template <bool B> PYBIND11_NOINLINE enable_if_t<B, descr> _(const char *text1, const char *) { return _(text1); }
|
||||
template <bool B> PYBIND11_NOINLINE enable_if_t<!B, descr> _(char const *, const char *text2) { return _(text2); }
|
||||
template <bool B> PYBIND11_NOINLINE enable_if_t<B, descr> _(descr d, descr) { return d; }
|
||||
template <bool B> PYBIND11_NOINLINE enable_if_t<!B, descr> _(descr, descr d) { return d; }
|
||||
|
||||
template <typename Type> PYBIND11_NOINLINE descr _() {
|
||||
const std::type_info *types[2] = { &typeid(Type), nullptr };
|
||||
return descr("%", types);
|
||||
}
|
||||
|
||||
template <size_t Size> PYBIND11_NOINLINE descr _() {
|
||||
const std::type_info *types[1] = { nullptr };
|
||||
return descr(std::to_string(Size).c_str(), types);
|
||||
}
|
||||
|
||||
PYBIND11_NOINLINE inline descr concat() { return _(""); }
|
||||
PYBIND11_NOINLINE inline descr concat(descr &&d) { return d; }
|
||||
template <typename... Args> PYBIND11_NOINLINE descr concat(descr &&d, Args&&... args) { return std::move(d) + _(", ") + concat(std::forward<Args>(args)...); }
|
||||
PYBIND11_NOINLINE inline descr type_descr(descr&& d) { return _("{") + std::move(d) + _("}"); }
|
||||
|
||||
#define PYBIND11_DESCR ::pybind11::detail::descr
|
||||
#endif
|
||||
|
||||
NAMESPACE_END(detail)
|
||||
NAMESPACE_END(pybind11)
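// Usage sketch (illustrative): the descr machinery above assembles the signature text
// embedded in generated docstrings. For the binding below the advertised signature
// would read roughly "scale(arg0: float, arg1: int) -> float" (the exact rendering
// depends on the pybind11 version). `descr_example` is a placeholder module name.
#include <pybind11/pybind11.h>
namespace py = pybind11;

PYBIND11_MODULE(descr_example, m) {
    m.def("scale", [](float value, int factor) { return value * factor; });
    // help(descr_example.scale) shows the generated signature in the docstring.
}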
@ -1,622 +0,0 @@
/*
|
||||
pybind11/detail/class.h: Python C API implementation details for py::class_
|
||||
|
||||
Copyright (c) 2017 Wenzel Jakob <wenzel.jakob@epfl.ch>
|
||||
|
||||
All rights reserved. Use of this source code is governed by a
|
||||
BSD-style license that can be found in the LICENSE file.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "../attr.h"
|
||||
|
||||
NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
|
||||
NAMESPACE_BEGIN(detail)
|
||||
|
||||
#if PY_VERSION_HEX >= 0x03030000
|
||||
# define PYBIND11_BUILTIN_QUALNAME
|
||||
# define PYBIND11_SET_OLDPY_QUALNAME(obj, nameobj)
|
||||
#else
|
||||
// In pre-3.3 Python, we still set __qualname__ so that we can produce reliable function type
|
||||
// signatures; in 3.3+ this macro expands to nothing:
|
||||
# define PYBIND11_SET_OLDPY_QUALNAME(obj, nameobj) setattr((PyObject *) obj, "__qualname__", nameobj)
|
||||
#endif
|
||||
|
||||
inline PyTypeObject *type_incref(PyTypeObject *type) {
|
||||
Py_INCREF(type);
|
||||
return type;
|
||||
}
|
||||
|
||||
#if !defined(PYPY_VERSION)
|
||||
|
||||
/// `pybind11_static_property.__get__()`: Always pass the class instead of the instance.
|
||||
extern "C" inline PyObject *pybind11_static_get(PyObject *self, PyObject * /*ob*/, PyObject *cls) {
|
||||
return PyProperty_Type.tp_descr_get(self, cls, cls);
|
||||
}
|
||||
|
||||
/// `pybind11_static_property.__set__()`: Just like the above `__get__()`.
|
||||
extern "C" inline int pybind11_static_set(PyObject *self, PyObject *obj, PyObject *value) {
|
||||
PyObject *cls = PyType_Check(obj) ? obj : (PyObject *) Py_TYPE(obj);
|
||||
return PyProperty_Type.tp_descr_set(self, cls, value);
|
||||
}
|
||||
|
||||
/** A `static_property` is the same as a `property` but the `__get__()` and `__set__()`
|
||||
methods are modified to always use the object type instead of a concrete instance.
|
||||
Return value: New reference. */
|
||||
inline PyTypeObject *make_static_property_type() {
|
||||
constexpr auto *name = "pybind11_static_property";
|
||||
auto name_obj = reinterpret_steal<object>(PYBIND11_FROM_STRING(name));
|
||||
|
||||
/* Danger zone: from now (and until PyType_Ready), make sure to
|
||||
issue no Python C API calls which could potentially invoke the
|
||||
garbage collector (the GC will call type_traverse(), which will in
|
||||
turn find the newly constructed type in an invalid state) */
|
||||
auto heap_type = (PyHeapTypeObject *) PyType_Type.tp_alloc(&PyType_Type, 0);
|
||||
if (!heap_type)
|
||||
pybind11_fail("make_static_property_type(): error allocating type!");
|
||||
|
||||
heap_type->ht_name = name_obj.inc_ref().ptr();
|
||||
#ifdef PYBIND11_BUILTIN_QUALNAME
|
||||
heap_type->ht_qualname = name_obj.inc_ref().ptr();
|
||||
#endif
|
||||
|
||||
auto type = &heap_type->ht_type;
|
||||
type->tp_name = name;
|
||||
type->tp_base = type_incref(&PyProperty_Type);
|
||||
type->tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HEAPTYPE;
|
||||
type->tp_descr_get = pybind11_static_get;
|
||||
type->tp_descr_set = pybind11_static_set;
|
||||
|
||||
if (PyType_Ready(type) < 0)
|
||||
pybind11_fail("make_static_property_type(): failure in PyType_Ready()!");
|
||||
|
||||
setattr((PyObject *) type, "__module__", str("pybind11_builtins"));
|
||||
PYBIND11_SET_OLDPY_QUALNAME(type, name_obj);
|
||||
|
||||
return type;
|
||||
}
|
||||
|
||||
#else // PYPY
|
||||
|
||||
/** PyPy has some issues with the above C API, so we evaluate Python code instead.
|
||||
This function will only be called once so performance isn't really a concern.
|
||||
Return value: New reference. */
|
||||
inline PyTypeObject *make_static_property_type() {
|
||||
auto d = dict();
|
||||
PyObject *result = PyRun_String(R"(\
|
||||
class pybind11_static_property(property):
|
||||
def __get__(self, obj, cls):
|
||||
return property.__get__(self, cls, cls)
|
||||
|
||||
def __set__(self, obj, value):
|
||||
cls = obj if isinstance(obj, type) else type(obj)
|
||||
property.__set__(self, cls, value)
|
||||
)", Py_file_input, d.ptr(), d.ptr()
|
||||
);
|
||||
if (result == nullptr)
|
||||
throw error_already_set();
|
||||
Py_DECREF(result);
|
||||
return (PyTypeObject *) d["pybind11_static_property"].cast<object>().release().ptr();
|
||||
}
|
||||
|
||||
#endif // PYPY
|
||||
|
||||
/** Types with static properties need to handle `Type.static_prop = x` in a specific way.
|
||||
By default, Python replaces the `static_property` itself, but for wrapped C++ types
|
||||
we need to call `static_property.__set__()` in order to propagate the new value to
|
||||
the underlying C++ data structure. */
|
||||
extern "C" inline int pybind11_meta_setattro(PyObject* obj, PyObject* name, PyObject* value) {
|
||||
// Use `_PyType_Lookup()` instead of `PyObject_GetAttr()` in order to get the raw
|
||||
// descriptor (`property`) instead of calling `tp_descr_get` (`property.__get__()`).
|
||||
PyObject *descr = _PyType_Lookup((PyTypeObject *) obj, name);
|
||||
|
||||
// The following assignment combinations are possible:
|
||||
// 1. `Type.static_prop = value` --> descr_set: `Type.static_prop.__set__(value)`
|
||||
// 2. `Type.static_prop = other_static_prop` --> setattro: replace existing `static_prop`
|
||||
// 3. `Type.regular_attribute = value` --> setattro: regular attribute assignment
|
||||
const auto static_prop = (PyObject *) get_internals().static_property_type;
|
||||
const auto call_descr_set = descr && PyObject_IsInstance(descr, static_prop)
|
||||
&& !PyObject_IsInstance(value, static_prop);
|
||||
if (call_descr_set) {
|
||||
// Call `static_property.__set__()` instead of replacing the `static_property`.
|
||||
#if !defined(PYPY_VERSION)
|
||||
return Py_TYPE(descr)->tp_descr_set(descr, obj, value);
|
||||
#else
|
||||
if (PyObject *result = PyObject_CallMethod(descr, "__set__", "OO", obj, value)) {
|
||||
Py_DECREF(result);
|
||||
return 0;
|
||||
} else {
|
||||
return -1;
|
||||
}
|
||||
#endif
|
||||
} else {
|
||||
// Replace existing attribute.
|
||||
return PyType_Type.tp_setattro(obj, name, value);
|
||||
}
|
||||
}
|
||||
|
||||
#if PY_MAJOR_VERSION >= 3
|
||||
/**
|
||||
* Python 3's PyInstanceMethod_Type hides itself via its tp_descr_get, which prevents aliasing
|
||||
* methods via cls.attr("m2") = cls.attr("m1"): instead the tp_descr_get returns a plain function,
|
||||
* when called on a class, or a PyMethod, when called on an instance. Override that behaviour here
|
||||
* to do a special case bypass for PyInstanceMethod_Types.
|
||||
*/
|
||||
extern "C" inline PyObject *pybind11_meta_getattro(PyObject *obj, PyObject *name) {
|
||||
PyObject *descr = _PyType_Lookup((PyTypeObject *) obj, name);
|
||||
if (descr && PyInstanceMethod_Check(descr)) {
|
||||
Py_INCREF(descr);
|
||||
return descr;
|
||||
}
|
||||
else {
|
||||
return PyType_Type.tp_getattro(obj, name);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
/** This metaclass is assigned by default to all pybind11 types and is required in order
|
||||
for static properties to function correctly. Users may override this using `py::metaclass`.
|
||||
Return value: New reference. */
|
||||
inline PyTypeObject* make_default_metaclass() {
|
||||
constexpr auto *name = "pybind11_type";
|
||||
auto name_obj = reinterpret_steal<object>(PYBIND11_FROM_STRING(name));
|
||||
|
||||
/* Danger zone: from now (and until PyType_Ready), make sure to
|
||||
issue no Python C API calls which could potentially invoke the
|
||||
garbage collector (the GC will call type_traverse(), which will in
|
||||
turn find the newly constructed type in an invalid state) */
|
||||
auto heap_type = (PyHeapTypeObject *) PyType_Type.tp_alloc(&PyType_Type, 0);
|
||||
if (!heap_type)
|
||||
pybind11_fail("make_default_metaclass(): error allocating metaclass!");
|
||||
|
||||
heap_type->ht_name = name_obj.inc_ref().ptr();
|
||||
#ifdef PYBIND11_BUILTIN_QUALNAME
|
||||
heap_type->ht_qualname = name_obj.inc_ref().ptr();
|
||||
#endif
|
||||
|
||||
auto type = &heap_type->ht_type;
|
||||
type->tp_name = name;
|
||||
type->tp_base = type_incref(&PyType_Type);
|
||||
type->tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HEAPTYPE;
|
||||
|
||||
type->tp_setattro = pybind11_meta_setattro;
|
||||
#if PY_MAJOR_VERSION >= 3
|
||||
type->tp_getattro = pybind11_meta_getattro;
|
||||
#endif
|
||||
|
||||
if (PyType_Ready(type) < 0)
|
||||
pybind11_fail("make_default_metaclass(): failure in PyType_Ready()!");
|
||||
|
||||
setattr((PyObject *) type, "__module__", str("pybind11_builtins"));
|
||||
PYBIND11_SET_OLDPY_QUALNAME(type, name_obj);
|
||||
|
||||
return type;
|
||||
}
|
||||
|
||||
/// For multiple inheritance types we need to recursively register/deregister base pointers for any
|
||||
/// base classes with pointers that are different from the instance value pointer so that we can
|
||||
/// correctly recognize an offset base class pointer. This calls a function with any offset base ptrs.
|
||||
inline void traverse_offset_bases(void *valueptr, const detail::type_info *tinfo, instance *self,
|
||||
bool (*f)(void * /*parentptr*/, instance * /*self*/)) {
|
||||
for (handle h : reinterpret_borrow<tuple>(tinfo->type->tp_bases)) {
|
||||
if (auto parent_tinfo = get_type_info((PyTypeObject *) h.ptr())) {
|
||||
for (auto &c : parent_tinfo->implicit_casts) {
|
||||
if (c.first == tinfo->cpptype) {
|
||||
auto *parentptr = c.second(valueptr);
|
||||
if (parentptr != valueptr)
|
||||
f(parentptr, self);
|
||||
traverse_offset_bases(parentptr, parent_tinfo, self, f);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
inline bool register_instance_impl(void *ptr, instance *self) {
|
||||
get_internals().registered_instances.emplace(ptr, self);
|
||||
return true; // unused, but gives the same signature as the deregister func
|
||||
}
|
||||
inline bool deregister_instance_impl(void *ptr, instance *self) {
|
||||
auto ®istered_instances = get_internals().registered_instances;
|
||||
auto range = registered_instances.equal_range(ptr);
|
||||
for (auto it = range.first; it != range.second; ++it) {
|
||||
if (Py_TYPE(self) == Py_TYPE(it->second)) {
|
||||
registered_instances.erase(it);
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
inline void register_instance(instance *self, void *valptr, const type_info *tinfo) {
|
||||
register_instance_impl(valptr, self);
|
||||
if (!tinfo->simple_ancestors)
|
||||
traverse_offset_bases(valptr, tinfo, self, register_instance_impl);
|
||||
}
|
||||
|
||||
inline bool deregister_instance(instance *self, void *valptr, const type_info *tinfo) {
|
||||
bool ret = deregister_instance_impl(valptr, self);
|
||||
if (!tinfo->simple_ancestors)
|
||||
traverse_offset_bases(valptr, tinfo, self, deregister_instance_impl);
|
||||
return ret;
|
||||
}
|
||||
|
||||
/// Instance creation function for all pybind11 types. It allocates the internal instance layout for
|
||||
/// holding C++ objects and holders. Allocation is done lazily (the first time the instance is cast
|
||||
/// to a reference or pointer), and initialization is done by an `__init__` function.
|
||||
inline PyObject *make_new_instance(PyTypeObject *type) {
|
||||
#if defined(PYPY_VERSION)
|
||||
// PyPy gets tp_basicsize wrong (issue 2482) under multiple inheritance when the first inherited
|
||||
// object is a plain Python type (i.e. not derived from an extension type). Fix it.
|
||||
ssize_t instance_size = static_cast<ssize_t>(sizeof(instance));
|
||||
if (type->tp_basicsize < instance_size) {
|
||||
type->tp_basicsize = instance_size;
|
||||
}
|
||||
#endif
|
||||
PyObject *self = type->tp_alloc(type, 0);
|
||||
auto inst = reinterpret_cast<instance *>(self);
|
||||
// Allocate the value/holder internals:
|
||||
inst->allocate_layout();
|
||||
|
||||
inst->owned = true;
|
||||
|
||||
return self;
|
||||
}
|
||||
|
||||
/// Instance creation function for all pybind11 types. It only allocates space for the
|
||||
/// C++ object, but doesn't call the constructor -- an `__init__` function must do that.
|
||||
extern "C" inline PyObject *pybind11_object_new(PyTypeObject *type, PyObject *, PyObject *) {
|
||||
return make_new_instance(type);
|
||||
}
|
||||
|
||||
/// An `__init__` function constructs the C++ object. Users should provide at least one
|
||||
/// of these using `py::init` or directly with `.def(__init__, ...)`. Otherwise, the
|
||||
/// following default function will be used which simply throws an exception.
|
||||
extern "C" inline int pybind11_object_init(PyObject *self, PyObject *, PyObject *) {
|
||||
PyTypeObject *type = Py_TYPE(self);
|
||||
std::string msg;
|
||||
#if defined(PYPY_VERSION)
|
||||
msg += handle((PyObject *) type).attr("__module__").cast<std::string>() + ".";
|
||||
#endif
|
||||
msg += type->tp_name;
|
||||
msg += ": No constructor defined!";
|
||||
PyErr_SetString(PyExc_TypeError, msg.c_str());
|
||||
return -1;
|
||||
}
|
||||
|
||||
inline void add_patient(PyObject *nurse, PyObject *patient) {
|
||||
auto &internals = get_internals();
|
||||
auto instance = reinterpret_cast<detail::instance *>(nurse);
|
||||
instance->has_patients = true;
|
||||
Py_INCREF(patient);
|
||||
internals.patients[nurse].push_back(patient);
|
||||
}
|
||||
|
||||
inline void clear_patients(PyObject *self) {
|
||||
auto instance = reinterpret_cast<detail::instance *>(self);
|
||||
auto &internals = get_internals();
|
||||
auto pos = internals.patients.find(self);
|
||||
assert(pos != internals.patients.end());
|
||||
// Clearing the patients can cause more Python code to run, which
|
||||
// can invalidate the iterator. Extract the vector of patients
|
||||
// from the unordered_map first.
|
||||
auto patients = std::move(pos->second);
|
||||
internals.patients.erase(pos);
|
||||
instance->has_patients = false;
|
||||
for (PyObject *&patient : patients)
|
||||
Py_CLEAR(patient);
|
||||
}
|
||||
|
||||
/// Clears all internal data from the instance and removes it from registered instances in
|
||||
/// preparation for deallocation.
|
||||
inline void clear_instance(PyObject *self) {
|
||||
auto instance = reinterpret_cast<detail::instance *>(self);
|
||||
|
||||
// Deallocate any values/holders, if present:
|
||||
for (auto &v_h : values_and_holders(instance)) {
|
||||
if (v_h) {
|
||||
|
||||
// We have to deregister before we call dealloc because, for virtual MI types, we still
|
||||
// need to be able to get the parent pointers.
|
||||
if (v_h.instance_registered() && !deregister_instance(instance, v_h.value_ptr(), v_h.type))
|
||||
pybind11_fail("pybind11_object_dealloc(): Tried to deallocate unregistered instance!");
|
||||
|
||||
if (instance->owned || v_h.holder_constructed())
|
||||
v_h.type->dealloc(v_h);
|
||||
}
|
||||
}
|
||||
// Deallocate the value/holder layout internals:
|
||||
instance->deallocate_layout();
|
||||
|
||||
if (instance->weakrefs)
|
||||
PyObject_ClearWeakRefs(self);
|
||||
|
||||
PyObject **dict_ptr = _PyObject_GetDictPtr(self);
|
||||
if (dict_ptr)
|
||||
Py_CLEAR(*dict_ptr);
|
||||
|
||||
if (instance->has_patients)
|
||||
clear_patients(self);
|
||||
}
|
||||
|
||||
/// Instance destructor function for all pybind11 types. It calls `type_info.dealloc`
|
||||
/// to destroy the C++ object itself, while the rest is Python bookkeeping.
|
||||
extern "C" inline void pybind11_object_dealloc(PyObject *self) {
|
||||
clear_instance(self);
|
||||
|
||||
auto type = Py_TYPE(self);
|
||||
type->tp_free(self);
|
||||
|
||||
// `type->tp_dealloc != pybind11_object_dealloc` means that we're being called
|
||||
// as part of a derived type's dealloc, in which case we're not allowed to decref
|
||||
// the type here. For cross-module compatibility, we shouldn't compare directly
|
||||
// with `pybind11_object_dealloc`, but with the common one stashed in internals.
|
||||
auto pybind11_object_type = (PyTypeObject *) get_internals().instance_base;
|
||||
if (type->tp_dealloc == pybind11_object_type->tp_dealloc)
|
||||
Py_DECREF(type);
|
||||
}
|
||||
|
||||
/** Create the type which can be used as a common base for all classes. This is
|
||||
needed in order to satisfy Python's requirements for multiple inheritance.
|
||||
Return value: New reference. */
|
||||
inline PyObject *make_object_base_type(PyTypeObject *metaclass) {
|
||||
constexpr auto *name = "pybind11_object";
|
||||
auto name_obj = reinterpret_steal<object>(PYBIND11_FROM_STRING(name));
|
||||
|
||||
/* Danger zone: from now (and until PyType_Ready), make sure to
|
||||
issue no Python C API calls which could potentially invoke the
|
||||
garbage collector (the GC will call type_traverse(), which will in
|
||||
turn find the newly constructed type in an invalid state) */
|
||||
auto heap_type = (PyHeapTypeObject *) metaclass->tp_alloc(metaclass, 0);
|
||||
if (!heap_type)
|
||||
pybind11_fail("make_object_base_type(): error allocating type!");
|
||||
|
||||
heap_type->ht_name = name_obj.inc_ref().ptr();
|
||||
#ifdef PYBIND11_BUILTIN_QUALNAME
|
||||
heap_type->ht_qualname = name_obj.inc_ref().ptr();
|
||||
#endif
|
||||
|
||||
auto type = &heap_type->ht_type;
|
||||
type->tp_name = name;
|
||||
type->tp_base = type_incref(&PyBaseObject_Type);
|
||||
type->tp_basicsize = static_cast<ssize_t>(sizeof(instance));
|
||||
type->tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HEAPTYPE;
|
||||
|
||||
type->tp_new = pybind11_object_new;
|
||||
type->tp_init = pybind11_object_init;
|
||||
type->tp_dealloc = pybind11_object_dealloc;
|
||||
|
||||
/* Support weak references (needed for the keep_alive feature) */
|
||||
type->tp_weaklistoffset = offsetof(instance, weakrefs);
|
||||
|
||||
if (PyType_Ready(type) < 0)
|
||||
pybind11_fail("PyType_Ready failed in make_object_base_type():" + error_string());
|
||||
|
||||
setattr((PyObject *) type, "__module__", str("pybind11_builtins"));
|
||||
PYBIND11_SET_OLDPY_QUALNAME(type, name_obj);
|
||||
|
||||
assert(!PyType_HasFeature(type, Py_TPFLAGS_HAVE_GC));
|
||||
return (PyObject *) heap_type;
|
||||
}
|
||||
|
||||
/// dynamic_attr: Support for `d = instance.__dict__`.
|
||||
extern "C" inline PyObject *pybind11_get_dict(PyObject *self, void *) {
|
||||
PyObject *&dict = *_PyObject_GetDictPtr(self);
|
||||
if (!dict)
|
||||
dict = PyDict_New();
|
||||
Py_XINCREF(dict);
|
||||
return dict;
|
||||
}
|
||||
|
||||
/// dynamic_attr: Support for `instance.__dict__ = dict()`.
|
||||
extern "C" inline int pybind11_set_dict(PyObject *self, PyObject *new_dict, void *) {
|
||||
if (!PyDict_Check(new_dict)) {
|
||||
PyErr_Format(PyExc_TypeError, "__dict__ must be set to a dictionary, not a '%.200s'",
|
||||
Py_TYPE(new_dict)->tp_name);
|
||||
return -1;
|
||||
}
|
||||
PyObject *&dict = *_PyObject_GetDictPtr(self);
|
||||
Py_INCREF(new_dict);
|
||||
Py_CLEAR(dict);
|
||||
dict = new_dict;
|
||||
return 0;
|
||||
}
|
||||
|
||||
/// dynamic_attr: Allow the garbage collector to traverse the internal instance `__dict__`.
|
||||
extern "C" inline int pybind11_traverse(PyObject *self, visitproc visit, void *arg) {
|
||||
PyObject *&dict = *_PyObject_GetDictPtr(self);
|
||||
Py_VISIT(dict);
|
||||
return 0;
|
||||
}
|
||||
|
||||
/// dynamic_attr: Allow the GC to clear the dictionary.
|
||||
extern "C" inline int pybind11_clear(PyObject *self) {
|
||||
PyObject *&dict = *_PyObject_GetDictPtr(self);
|
||||
Py_CLEAR(dict);
|
||||
return 0;
|
||||
}
|
||||
|
||||
/// Give instances of this type a `__dict__` and opt into garbage collection.
|
||||
inline void enable_dynamic_attributes(PyHeapTypeObject *heap_type) {
|
||||
auto type = &heap_type->ht_type;
|
||||
#if defined(PYPY_VERSION)
|
||||
pybind11_fail(std::string(type->tp_name) + ": dynamic attributes are "
|
||||
"currently not supported in "
|
||||
"conjunction with PyPy!");
|
||||
#endif
|
||||
type->tp_flags |= Py_TPFLAGS_HAVE_GC;
|
||||
type->tp_dictoffset = type->tp_basicsize; // place dict at the end
|
||||
type->tp_basicsize += (ssize_t)sizeof(PyObject *); // and allocate enough space for it
|
||||
type->tp_traverse = pybind11_traverse;
|
||||
type->tp_clear = pybind11_clear;
|
||||
|
||||
static PyGetSetDef getset[] = {
|
||||
{const_cast<char*>("__dict__"), pybind11_get_dict, pybind11_set_dict, nullptr, nullptr},
|
||||
{nullptr, nullptr, nullptr, nullptr, nullptr}
|
||||
};
|
||||
type->tp_getset = getset;
|
||||
}
|
||||
|
||||
/// buffer_protocol: Fill in the view as specified by flags.
|
||||
extern "C" inline int pybind11_getbuffer(PyObject *obj, Py_buffer *view, int flags) {
|
||||
// Look for a `get_buffer` implementation in this type's info or any bases (following MRO).
|
||||
type_info *tinfo = nullptr;
|
||||
for (auto type : reinterpret_borrow<tuple>(Py_TYPE(obj)->tp_mro)) {
|
||||
tinfo = get_type_info((PyTypeObject *) type.ptr());
|
||||
if (tinfo && tinfo->get_buffer)
|
||||
break;
|
||||
}
|
||||
if (view == nullptr || obj == nullptr || !tinfo || !tinfo->get_buffer) {
|
||||
if (view)
|
||||
view->obj = nullptr;
|
||||
PyErr_SetString(PyExc_BufferError, "pybind11_getbuffer(): Internal error");
|
||||
return -1;
|
||||
}
|
||||
std::memset(view, 0, sizeof(Py_buffer));
|
||||
buffer_info *info = tinfo->get_buffer(obj, tinfo->get_buffer_data);
|
||||
view->obj = obj;
|
||||
view->ndim = 1;
|
||||
view->internal = info;
|
||||
view->buf = info->ptr;
|
||||
view->itemsize = info->itemsize;
|
||||
view->len = view->itemsize;
|
||||
for (auto s : info->shape)
|
||||
view->len *= s;
|
||||
if ((flags & PyBUF_FORMAT) == PyBUF_FORMAT)
|
||||
view->format = const_cast<char *>(info->format.c_str());
|
||||
if ((flags & PyBUF_STRIDES) == PyBUF_STRIDES) {
|
||||
view->ndim = (int) info->ndim;
|
||||
view->strides = &info->strides[0];
|
||||
view->shape = &info->shape[0];
|
||||
}
|
||||
Py_INCREF(view->obj);
|
||||
return 0;
|
||||
}
|
||||
|
||||
/// buffer_protocol: Release the resources of the buffer.
|
||||
extern "C" inline void pybind11_releasebuffer(PyObject *, Py_buffer *view) {
|
||||
delete (buffer_info *) view->internal;
|
||||
}
|
||||
|
||||
/// Give this type a buffer interface.
|
||||
inline void enable_buffer_protocol(PyHeapTypeObject *heap_type) {
|
||||
heap_type->ht_type.tp_as_buffer = &heap_type->as_buffer;
|
||||
#if PY_MAJOR_VERSION < 3
|
||||
heap_type->ht_type.tp_flags |= Py_TPFLAGS_HAVE_NEWBUFFER;
|
||||
#endif
|
||||
|
||||
heap_type->as_buffer.bf_getbuffer = pybind11_getbuffer;
|
||||
heap_type->as_buffer.bf_releasebuffer = pybind11_releasebuffer;
|
||||
}
|
||||
|
||||
/** Create a brand new Python type according to the `type_record` specification.
|
||||
Return value: New reference. */
|
||||
inline PyObject* make_new_python_type(const type_record &rec) {
|
||||
auto name = reinterpret_steal<object>(PYBIND11_FROM_STRING(rec.name));
|
||||
|
||||
auto qualname = name;
|
||||
if (rec.scope && !PyModule_Check(rec.scope.ptr()) && hasattr(rec.scope, "__qualname__")) {
|
||||
#if PY_MAJOR_VERSION >= 3
|
||||
qualname = reinterpret_steal<object>(
|
||||
PyUnicode_FromFormat("%U.%U", rec.scope.attr("__qualname__").ptr(), name.ptr()));
|
||||
#else
|
||||
qualname = str(rec.scope.attr("__qualname__").cast<std::string>() + "." + rec.name);
|
||||
#endif
|
||||
}
|
||||
|
||||
object module;
|
||||
if (rec.scope) {
|
||||
if (hasattr(rec.scope, "__module__"))
|
||||
module = rec.scope.attr("__module__");
|
||||
else if (hasattr(rec.scope, "__name__"))
|
||||
module = rec.scope.attr("__name__");
|
||||
}
|
||||
|
||||
auto full_name = c_str(
|
||||
#if !defined(PYPY_VERSION)
|
||||
module ? str(module).cast<std::string>() + "." + rec.name :
|
||||
#endif
|
||||
rec.name);
|
||||
|
||||
char *tp_doc = nullptr;
|
||||
if (rec.doc && options::show_user_defined_docstrings()) {
|
||||
/* Allocate memory for docstring (using PyObject_MALLOC, since
|
||||
Python will free this later on) */
|
||||
size_t size = strlen(rec.doc) + 1;
|
||||
tp_doc = (char *) PyObject_MALLOC(size);
|
||||
memcpy((void *) tp_doc, rec.doc, size);
|
||||
}
|
||||
|
||||
auto &internals = get_internals();
|
||||
auto bases = tuple(rec.bases);
|
||||
auto base = (bases.size() == 0) ? internals.instance_base
|
||||
: bases[0].ptr();
|
||||
|
||||
/* Danger zone: from now (and until PyType_Ready), make sure to
|
||||
issue no Python C API calls which could potentially invoke the
|
||||
garbage collector (the GC will call type_traverse(), which will in
|
||||
turn find the newly constructed type in an invalid state) */
|
||||
auto metaclass = rec.metaclass.ptr() ? (PyTypeObject *) rec.metaclass.ptr()
|
||||
: internals.default_metaclass;
|
||||
|
||||
auto heap_type = (PyHeapTypeObject *) metaclass->tp_alloc(metaclass, 0);
|
||||
if (!heap_type)
|
||||
pybind11_fail(std::string(rec.name) + ": Unable to create type object!");
|
||||
|
||||
heap_type->ht_name = name.release().ptr();
|
||||
#ifdef PYBIND11_BUILTIN_QUALNAME
|
||||
heap_type->ht_qualname = qualname.inc_ref().ptr();
|
||||
#endif
|
||||
|
||||
auto type = &heap_type->ht_type;
|
||||
type->tp_name = full_name;
|
||||
type->tp_doc = tp_doc;
|
||||
type->tp_base = type_incref((PyTypeObject *)base);
|
||||
type->tp_basicsize = static_cast<ssize_t>(sizeof(instance));
|
||||
if (bases.size() > 0)
|
||||
type->tp_bases = bases.release().ptr();
|
||||
|
||||
/* Don't inherit base __init__ */
|
||||
type->tp_init = pybind11_object_init;
|
||||
|
||||
/* Supported protocols */
|
||||
type->tp_as_number = &heap_type->as_number;
|
||||
type->tp_as_sequence = &heap_type->as_sequence;
|
||||
type->tp_as_mapping = &heap_type->as_mapping;
|
||||
|
||||
/* Flags */
|
||||
type->tp_flags |= Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HEAPTYPE;
|
||||
#if PY_MAJOR_VERSION < 3
|
||||
type->tp_flags |= Py_TPFLAGS_CHECKTYPES;
|
||||
#endif
|
||||
|
||||
if (rec.dynamic_attr)
|
||||
enable_dynamic_attributes(heap_type);
|
||||
|
||||
if (rec.buffer_protocol)
|
||||
enable_buffer_protocol(heap_type);
|
||||
|
||||
if (PyType_Ready(type) < 0)
|
||||
pybind11_fail(std::string(rec.name) + ": PyType_Ready failed (" + error_string() + ")!");
|
||||
|
||||
assert(rec.dynamic_attr ? PyType_HasFeature(type, Py_TPFLAGS_HAVE_GC)
|
||||
: !PyType_HasFeature(type, Py_TPFLAGS_HAVE_GC));
|
||||
|
||||
/* Register type with the parent scope */
|
||||
if (rec.scope)
|
||||
setattr(rec.scope, rec.name, (PyObject *) type);
|
||||
else
|
||||
Py_INCREF(type); // Keep it alive forever (reference leak)
|
||||
|
||||
if (module) // Needed by pydoc
|
||||
setattr((PyObject *) type, "__module__", module);
|
||||
|
||||
PYBIND11_SET_OLDPY_QUALNAME(type, qualname);
|
||||
|
||||
return (PyObject *) type;
|
||||
}
|
||||
|
||||
NAMESPACE_END(detail)
|
||||
NAMESPACE_END(PYBIND11_NAMESPACE)
@@ -1,807 +0,0 @@
/*
|
||||
pybind11/detail/common.h -- Basic macros
|
||||
|
||||
Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>
|
||||
|
||||
All rights reserved. Use of this source code is governed by a
|
||||
BSD-style license that can be found in the LICENSE file.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#if !defined(NAMESPACE_BEGIN)
|
||||
# define NAMESPACE_BEGIN(name) namespace name {
|
||||
#endif
|
||||
#if !defined(NAMESPACE_END)
|
||||
# define NAMESPACE_END(name) }
|
||||
#endif
|
||||
|
||||
// Robust support for some features and loading modules compiled against different pybind versions
|
||||
// requires forcing hidden visibility on pybind code, so we enforce this by setting the attribute on
|
||||
// the main `pybind11` namespace.
|
||||
#if !defined(PYBIND11_NAMESPACE)
|
||||
# ifdef __GNUG__
|
||||
# define PYBIND11_NAMESPACE pybind11 __attribute__((visibility("hidden")))
|
||||
# else
|
||||
# define PYBIND11_NAMESPACE pybind11
|
||||
# endif
|
||||
#endif
|
||||
|
||||
#if !(defined(_MSC_VER) && __cplusplus == 199711L) && !defined(__INTEL_COMPILER)
|
||||
# if __cplusplus >= 201402L
|
||||
# define PYBIND11_CPP14
|
||||
# if __cplusplus >= 201703L
|
||||
# define PYBIND11_CPP17
|
||||
# endif
|
||||
# endif
|
||||
#elif defined(_MSC_VER) && __cplusplus == 199711L
|
||||
// MSVC sets _MSVC_LANG rather than __cplusplus (supposedly until the standard is fully implemented)
|
||||
// Unless you use the /Zc:__cplusplus flag on Visual Studio 2017 15.7 Preview 3 or newer
|
||||
# if _MSVC_LANG >= 201402L
|
||||
# define PYBIND11_CPP14
|
||||
# if _MSVC_LANG > 201402L && _MSC_VER >= 1910
|
||||
# define PYBIND11_CPP17
|
||||
# endif
|
||||
# endif
|
||||
#endif
|
||||
|
||||
// Compiler version assertions
|
||||
#if defined(__INTEL_COMPILER)
|
||||
# if __INTEL_COMPILER < 1700
|
||||
# error pybind11 requires Intel C++ compiler v17 or newer
|
||||
# endif
|
||||
#elif defined(__clang__) && !defined(__apple_build_version__)
|
||||
# if __clang_major__ < 3 || (__clang_major__ == 3 && __clang_minor__ < 3)
|
||||
# error pybind11 requires clang 3.3 or newer
|
||||
# endif
|
||||
#elif defined(__clang__)
|
||||
// Apple changes clang version macros to its Xcode version; the first Xcode release based on
|
||||
// (upstream) clang 3.3 was Xcode 5:
|
||||
# if __clang_major__ < 5
|
||||
# error pybind11 requires Xcode/clang 5.0 or newer
|
||||
# endif
|
||||
#elif defined(__GNUG__)
|
||||
# if __GNUC__ < 4 || (__GNUC__ == 4 && __GNUC_MINOR__ < 8)
|
||||
# error pybind11 requires gcc 4.8 or newer
|
||||
# endif
|
||||
#elif defined(_MSC_VER)
|
||||
// Pybind hits various compiler bugs in 2015u2 and earlier, and also makes use of some stl features
|
||||
// (e.g. std::negation) added in 2015u3:
|
||||
# if _MSC_FULL_VER < 190024210
|
||||
# error pybind11 requires MSVC 2015 update 3 or newer
|
||||
# endif
|
||||
#endif
|
||||
|
||||
#if !defined(PYBIND11_EXPORT)
|
||||
# if defined(WIN32) || defined(_WIN32)
|
||||
# define PYBIND11_EXPORT __declspec(dllexport)
|
||||
# else
|
||||
# define PYBIND11_EXPORT __attribute__ ((visibility("default")))
|
||||
# endif
|
||||
#endif
|
||||
|
||||
#if defined(_MSC_VER)
|
||||
# define PYBIND11_NOINLINE __declspec(noinline)
|
||||
#else
|
||||
# define PYBIND11_NOINLINE __attribute__ ((noinline))
|
||||
#endif
|
||||
|
||||
#if defined(PYBIND11_CPP14)
|
||||
# define PYBIND11_DEPRECATED(reason) [[deprecated(reason)]]
|
||||
#else
|
||||
# define PYBIND11_DEPRECATED(reason) __attribute__((deprecated(reason)))
|
||||
#endif
|
||||
|
||||
#define PYBIND11_VERSION_MAJOR 2
|
||||
#define PYBIND11_VERSION_MINOR 3
|
||||
#define PYBIND11_VERSION_PATCH dev0
|
||||
|
||||
/// Include Python header, disable linking to pythonX_d.lib on Windows in debug mode
|
||||
#if defined(_MSC_VER)
|
||||
# if (PY_MAJOR_VERSION == 3 && PY_MINOR_VERSION < 4)
|
||||
# define HAVE_ROUND 1
|
||||
# endif
|
||||
# pragma warning(push)
|
||||
# pragma warning(disable: 4510 4610 4512 4005)
|
||||
# if defined(_DEBUG)
|
||||
# define PYBIND11_DEBUG_MARKER
|
||||
# undef _DEBUG
|
||||
# endif
|
||||
#endif
|
||||
|
||||
#include <Python.h>
|
||||
#include <frameobject.h>
|
||||
#include <pythread.h>
|
||||
|
||||
#if defined(_WIN32) && (defined(min) || defined(max))
|
||||
# error Macro clash with min and max -- define NOMINMAX when compiling your program on Windows
|
||||
#endif
|
||||
|
||||
#if defined(isalnum)
|
||||
# undef isalnum
|
||||
# undef isalpha
|
||||
# undef islower
|
||||
# undef isspace
|
||||
# undef isupper
|
||||
# undef tolower
|
||||
# undef toupper
|
||||
#endif
|
||||
|
||||
#if defined(_MSC_VER)
|
||||
# if defined(PYBIND11_DEBUG_MARKER)
|
||||
# define _DEBUG
|
||||
# undef PYBIND11_DEBUG_MARKER
|
||||
# endif
|
||||
# pragma warning(pop)
|
||||
#endif
|
||||
|
||||
#include <cstddef>
|
||||
#include <cstring>
|
||||
#include <forward_list>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <stdexcept>
|
||||
#include <unordered_set>
|
||||
#include <unordered_map>
|
||||
#include <memory>
|
||||
#include <typeindex>
|
||||
#include <type_traits>
|
||||
|
||||
#if PY_MAJOR_VERSION >= 3 /// Compatibility macros for various Python versions
|
||||
#define PYBIND11_INSTANCE_METHOD_NEW(ptr, class_) PyInstanceMethod_New(ptr)
|
||||
#define PYBIND11_INSTANCE_METHOD_CHECK PyInstanceMethod_Check
|
||||
#define PYBIND11_INSTANCE_METHOD_GET_FUNCTION PyInstanceMethod_GET_FUNCTION
|
||||
#define PYBIND11_BYTES_CHECK PyBytes_Check
|
||||
#define PYBIND11_BYTES_FROM_STRING PyBytes_FromString
|
||||
#define PYBIND11_BYTES_FROM_STRING_AND_SIZE PyBytes_FromStringAndSize
|
||||
#define PYBIND11_BYTES_AS_STRING_AND_SIZE PyBytes_AsStringAndSize
|
||||
#define PYBIND11_BYTES_AS_STRING PyBytes_AsString
|
||||
#define PYBIND11_BYTES_SIZE PyBytes_Size
|
||||
#define PYBIND11_LONG_CHECK(o) PyLong_Check(o)
|
||||
#define PYBIND11_LONG_AS_LONGLONG(o) PyLong_AsLongLong(o)
|
||||
#define PYBIND11_LONG_FROM_SIGNED(o) PyLong_FromSsize_t((ssize_t) o)
|
||||
#define PYBIND11_LONG_FROM_UNSIGNED(o) PyLong_FromSize_t((size_t) o)
|
||||
#define PYBIND11_BYTES_NAME "bytes"
|
||||
#define PYBIND11_STRING_NAME "str"
|
||||
#define PYBIND11_SLICE_OBJECT PyObject
|
||||
#define PYBIND11_FROM_STRING PyUnicode_FromString
|
||||
#define PYBIND11_STR_TYPE ::pybind11::str
|
||||
#define PYBIND11_BOOL_ATTR "__bool__"
|
||||
#define PYBIND11_NB_BOOL(ptr) ((ptr)->nb_bool)
|
||||
#define PYBIND11_PLUGIN_IMPL(name) \
|
||||
extern "C" PYBIND11_EXPORT PyObject *PyInit_##name()
|
||||
|
||||
#else
|
||||
#define PYBIND11_INSTANCE_METHOD_NEW(ptr, class_) PyMethod_New(ptr, nullptr, class_)
|
||||
#define PYBIND11_INSTANCE_METHOD_CHECK PyMethod_Check
|
||||
#define PYBIND11_INSTANCE_METHOD_GET_FUNCTION PyMethod_GET_FUNCTION
|
||||
#define PYBIND11_BYTES_CHECK PyString_Check
|
||||
#define PYBIND11_BYTES_FROM_STRING PyString_FromString
|
||||
#define PYBIND11_BYTES_FROM_STRING_AND_SIZE PyString_FromStringAndSize
|
||||
#define PYBIND11_BYTES_AS_STRING_AND_SIZE PyString_AsStringAndSize
|
||||
#define PYBIND11_BYTES_AS_STRING PyString_AsString
|
||||
#define PYBIND11_BYTES_SIZE PyString_Size
|
||||
#define PYBIND11_LONG_CHECK(o) (PyInt_Check(o) || PyLong_Check(o))
|
||||
#define PYBIND11_LONG_AS_LONGLONG(o) (PyInt_Check(o) ? (long long) PyLong_AsLong(o) : PyLong_AsLongLong(o))
|
||||
#define PYBIND11_LONG_FROM_SIGNED(o) PyInt_FromSsize_t((ssize_t) o) // Returns long if needed.
|
||||
#define PYBIND11_LONG_FROM_UNSIGNED(o) PyInt_FromSize_t((size_t) o) // Returns long if needed.
|
||||
#define PYBIND11_BYTES_NAME "str"
|
||||
#define PYBIND11_STRING_NAME "unicode"
|
||||
#define PYBIND11_SLICE_OBJECT PySliceObject
|
||||
#define PYBIND11_FROM_STRING PyString_FromString
|
||||
#define PYBIND11_STR_TYPE ::pybind11::bytes
|
||||
#define PYBIND11_BOOL_ATTR "__nonzero__"
|
||||
#define PYBIND11_NB_BOOL(ptr) ((ptr)->nb_nonzero)
|
||||
#define PYBIND11_PLUGIN_IMPL(name) \
|
||||
static PyObject *pybind11_init_wrapper(); \
|
||||
extern "C" PYBIND11_EXPORT void init##name() { \
|
||||
(void)pybind11_init_wrapper(); \
|
||||
} \
|
||||
PyObject *pybind11_init_wrapper()
|
||||
#endif
|
||||
|
||||
#if PY_VERSION_HEX >= 0x03050000 && PY_VERSION_HEX < 0x03050200
|
||||
extern "C" {
|
||||
struct _Py_atomic_address { void *value; };
|
||||
PyAPI_DATA(_Py_atomic_address) _PyThreadState_Current;
|
||||
}
|
||||
#endif
|
||||
|
||||
#define PYBIND11_TRY_NEXT_OVERLOAD ((PyObject *) 1) // special failure return code
|
||||
#define PYBIND11_STRINGIFY(x) #x
|
||||
#define PYBIND11_TOSTRING(x) PYBIND11_STRINGIFY(x)
|
||||
#define PYBIND11_CONCAT(first, second) first##second
|
||||
|
||||
#define PYBIND11_CHECK_PYTHON_VERSION \
|
||||
{ \
|
||||
const char *compiled_ver = PYBIND11_TOSTRING(PY_MAJOR_VERSION) \
|
||||
"." PYBIND11_TOSTRING(PY_MINOR_VERSION); \
|
||||
const char *runtime_ver = Py_GetVersion(); \
|
||||
size_t len = std::strlen(compiled_ver); \
|
||||
if (std::strncmp(runtime_ver, compiled_ver, len) != 0 \
|
||||
|| (runtime_ver[len] >= '0' && runtime_ver[len] <= '9')) { \
|
||||
PyErr_Format(PyExc_ImportError, \
|
||||
"Python version mismatch: module was compiled for Python %s, " \
|
||||
"but the interpreter version is incompatible: %s.", \
|
||||
compiled_ver, runtime_ver); \
|
||||
return nullptr; \
|
||||
} \
|
||||
}
|
||||
|
||||
#define PYBIND11_CATCH_INIT_EXCEPTIONS \
|
||||
catch (pybind11::error_already_set &e) { \
|
||||
PyErr_SetString(PyExc_ImportError, e.what()); \
|
||||
return nullptr; \
|
||||
} catch (const std::exception &e) { \
|
||||
PyErr_SetString(PyExc_ImportError, e.what()); \
|
||||
return nullptr; \
|
||||
} \
|
||||
|
||||
/** \rst
|
||||
***Deprecated in favor of PYBIND11_MODULE***
|
||||
|
||||
This macro creates the entry point that will be invoked when the Python interpreter
|
||||
imports a plugin library. Please create a `module` in the function body and return
|
||||
the pointer to its underlying Python object at the end.
|
||||
|
||||
.. code-block:: cpp
|
||||
|
||||
PYBIND11_PLUGIN(example) {
|
||||
pybind11::module m("example", "pybind11 example plugin");
|
||||
/// Set up bindings here
|
||||
return m.ptr();
|
||||
}
|
||||
\endrst */
|
||||
#define PYBIND11_PLUGIN(name) \
|
||||
PYBIND11_DEPRECATED("PYBIND11_PLUGIN is deprecated, use PYBIND11_MODULE") \
|
||||
static PyObject *pybind11_init(); \
|
||||
PYBIND11_PLUGIN_IMPL(name) { \
|
||||
PYBIND11_CHECK_PYTHON_VERSION \
|
||||
try { \
|
||||
return pybind11_init(); \
|
||||
} PYBIND11_CATCH_INIT_EXCEPTIONS \
|
||||
} \
|
||||
PyObject *pybind11_init()
|
||||
|
||||
/** \rst
|
||||
This macro creates the entry point that will be invoked when the Python interpreter
|
||||
imports an extension module. The module name is given as the first argument and it
should not be in quotes. The second macro argument defines a variable of type
|
||||
`py::module` which can be used to initialize the module.
|
||||
|
||||
.. code-block:: cpp
|
||||
|
||||
PYBIND11_MODULE(example, m) {
|
||||
m.doc() = "pybind11 example module";
|
||||
|
||||
// Add bindings here
|
||||
m.def("foo", []() {
|
||||
return "Hello, World!";
|
||||
});
|
||||
}
|
||||
\endrst */
|
||||
#define PYBIND11_MODULE(name, variable) \
|
||||
static void PYBIND11_CONCAT(pybind11_init_, name)(pybind11::module &); \
|
||||
PYBIND11_PLUGIN_IMPL(name) { \
|
||||
PYBIND11_CHECK_PYTHON_VERSION \
|
||||
auto m = pybind11::module(PYBIND11_TOSTRING(name)); \
|
||||
try { \
|
||||
PYBIND11_CONCAT(pybind11_init_, name)(m); \
|
||||
return m.ptr(); \
|
||||
} PYBIND11_CATCH_INIT_EXCEPTIONS \
|
||||
} \
|
||||
void PYBIND11_CONCAT(pybind11_init_, name)(pybind11::module &variable)
|
||||
|
||||
|
||||
NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
|
||||
|
||||
using ssize_t = Py_ssize_t;
|
||||
using size_t = std::size_t;
|
||||
|
||||
/// Approach used to cast a previously unknown C++ instance into a Python object
|
||||
enum class return_value_policy : uint8_t {
|
||||
/** This is the default return value policy, which falls back to the policy
|
||||
return_value_policy::take_ownership when the return value is a pointer.
|
||||
Otherwise, it uses return_value_policy::move or return_value_policy::copy for rvalue
and lvalue references, respectively. See below for a description of what
|
||||
all of these different policies do. */
|
||||
automatic = 0,
|
||||
|
||||
/** As above, but use policy return_value_policy::reference when the return
|
||||
value is a pointer. This is the default conversion policy for function
|
||||
arguments when calling Python functions manually from C++ code (i.e. via
|
||||
handle::operator()). You probably won't need to use this. */
|
||||
automatic_reference,
|
||||
|
||||
/** Reference an existing object (i.e. do not create a new copy) and take
|
||||
ownership. Python will call the destructor and delete operator when the
|
||||
object’s reference count reaches zero. Undefined behavior ensues when
|
||||
the C++ side does the same. */
take_ownership,
|
||||
|
||||
/** Create a new copy of the returned object, which will be owned by
|
||||
Python. This policy is comparably safe because the lifetimes of the two
|
||||
instances are decoupled. */
|
||||
copy,
|
||||
|
||||
/** Use std::move to move the return value contents into a new instance
|
||||
that will be owned by Python. This policy is comparably safe because the
|
||||
lifetimes of the two instances (move source and destination) are
|
||||
decoupled. */
|
||||
move,
|
||||
|
||||
/** Reference an existing object, but do not take ownership. The C++ side
|
||||
is responsible for managing the object’s lifetime and deallocating it
|
||||
when it is no longer used. Warning: undefined behavior will ensue when
|
||||
the C++ side deletes an object that is still referenced and used by
|
||||
Python. */
|
||||
reference,
|
||||
|
||||
/** This policy only applies to methods and properties. It references the
|
||||
object without taking ownership similar to the above
|
||||
return_value_policy::reference policy. In contrast to that policy, the
|
||||
function or property’s implicit this argument (called the parent) is
|
||||
considered to be the owner of the return value (the child).
pybind11 then couples the lifetime of the parent to the child via a
|
||||
reference relationship that ensures that the parent cannot be garbage
|
||||
collected while Python is still using the child. More advanced
|
||||
variations of this scheme are also possible using combinations of
|
||||
return_value_policy::reference and the keep_alive call policy */
|
||||
reference_internal
|
||||
};
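The policies above are normally selected at binding time. A minimal illustrative sketch (not part of this diff; `Registry`, `Item`, and the module name `example` are hypothetical):

```cpp
// Hypothetical example: choosing a return value policy at binding time.
#include <pybind11/pybind11.h>
namespace py = pybind11;

struct Item { int value = 0; };
struct Registry {
    Item item;
    Item &get_item() { return item; }  // reference owned by this Registry
};

PYBIND11_MODULE(example, m) {
    py::class_<Item>(m, "Item")
        .def_readwrite("value", &Item::value);
    py::class_<Registry>(m, "Registry")
        .def(py::init<>())
        // reference_internal: the Item is not copied, and the parent Registry
        // is kept alive for as long as Python holds the returned Item.
        .def("get_item", &Registry::get_item, py::return_value_policy::reference_internal);
}
```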
|
||||
|
||||
NAMESPACE_BEGIN(detail)
|
||||
|
||||
inline static constexpr int log2(size_t n, int k = 0) { return (n <= 1) ? k : log2(n >> 1, k + 1); }
|
||||
|
||||
// Returns the size as a multiple of sizeof(void *), rounded up.
|
||||
inline static constexpr size_t size_in_ptrs(size_t s) { return 1 + ((s - 1) >> log2(sizeof(void *))); }
|
||||
|
||||
/**
|
||||
* The space to allocate for simple layout instance holders (see below) in multiple of the size of
|
||||
* a pointer (e.g. 2 means 16 bytes on 64-bit architectures). The default is the minimum required
|
||||
 * to hold either a std::unique_ptr or std::shared_ptr (which is almost always
* sizeof(std::shared_ptr<T>)).
|
||||
*/
|
||||
constexpr size_t instance_simple_holder_in_ptrs() {
|
||||
static_assert(sizeof(std::shared_ptr<int>) >= sizeof(std::unique_ptr<int>),
|
||||
"pybind assumes std::shared_ptrs are at least as big as std::unique_ptrs");
|
||||
return size_in_ptrs(sizeof(std::shared_ptr<int>));
|
||||
}
|
||||
|
||||
// Forward declarations
|
||||
struct type_info;
|
||||
struct value_and_holder;
|
||||
|
||||
struct nonsimple_values_and_holders {
|
||||
void **values_and_holders;
|
||||
uint8_t *status;
|
||||
};
|
||||
|
||||
/// The 'instance' type which needs to be standard layout (need to be able to use 'offsetof')
|
||||
struct instance {
|
||||
PyObject_HEAD
|
||||
/// Storage for pointers and holder; see simple_layout, below, for a description
|
||||
union {
|
||||
void *simple_value_holder[1 + instance_simple_holder_in_ptrs()];
|
||||
nonsimple_values_and_holders nonsimple;
|
||||
};
|
||||
/// Weak references
|
||||
PyObject *weakrefs;
|
||||
/// If true, the pointer is owned which means we're free to manage it with a holder.
|
||||
bool owned : 1;
|
||||
/**
|
||||
* An instance has two possible value/holder layouts.
|
||||
*
|
||||
* Simple layout (when this flag is true), means the `simple_value_holder` is set with a pointer
|
||||
* and the holder object governing that pointer, i.e. [val1*][holder]. This layout is applied
|
||||
* whenever there is no python-side multiple inheritance of bound C++ types *and* the type's
|
||||
* holder will fit in the default space (which is large enough to hold either a std::unique_ptr
|
||||
* or std::shared_ptr).
|
||||
*
|
||||
* Non-simple layout applies when using custom holders that require more space than `shared_ptr`
|
||||
* (which is typically the size of two pointers), or when multiple inheritance is used on the
|
||||
* python side. Non-simple layout allocates the required amount of memory to have multiple
|
||||
* bound C++ classes as parents. Under this layout, `nonsimple.values_and_holders` is set to a
|
||||
 * pointer to an allocated block large enough to hold a sequence of value pointers and
 * holders followed by `status`, a set of bit flags (1 byte each), i.e.
* [val1*][holder1][val2*][holder2]...[bb...] where each [block] is rounded up to a multiple of
|
||||
* `sizeof(void *)`. `nonsimple.status` is, for convenience, a pointer to the
|
||||
* beginning of the [bb...] block (but not independently allocated).
|
||||
*
|
||||
* Status bits indicate whether the associated holder is constructed (&
|
||||
* status_holder_constructed) and whether the value pointer is registered (&
|
||||
* status_instance_registered) in `registered_instances`.
|
||||
*/
|
||||
bool simple_layout : 1;
|
||||
/// For simple layout, tracks whether the holder has been constructed
|
||||
bool simple_holder_constructed : 1;
|
||||
/// For simple layout, tracks whether the instance is registered in `registered_instances`
|
||||
bool simple_instance_registered : 1;
|
||||
/// If true, get_internals().patients has an entry for this object
|
||||
bool has_patients : 1;
|
||||
|
||||
/// Initializes all of the above type/values/holders data (but not the instance values themselves)
|
||||
void allocate_layout();
|
||||
|
||||
/// Destroys/deallocates all of the above
|
||||
void deallocate_layout();
|
||||
|
||||
/// Returns the value_and_holder wrapper for the given type (or the first, if `find_type`
|
||||
/// omitted). Returns a default-constructed (with `.inst = nullptr`) object on failure if
|
||||
/// `throw_if_missing` is false.
|
||||
value_and_holder get_value_and_holder(const type_info *find_type = nullptr, bool throw_if_missing = true);
|
||||
|
||||
/// Bit values for the non-simple status flags
|
||||
static constexpr uint8_t status_holder_constructed = 1;
|
||||
static constexpr uint8_t status_instance_registered = 2;
|
||||
};
|
||||
|
||||
static_assert(std::is_standard_layout<instance>::value, "Internal error: `pybind11::detail::instance` is not standard layout!");
|
||||
|
||||
/// from __cpp_future__ import (convenient aliases from C++14/17)
|
||||
#if defined(PYBIND11_CPP14) && (!defined(_MSC_VER) || _MSC_VER >= 1910)
|
||||
using std::enable_if_t;
|
||||
using std::conditional_t;
|
||||
using std::remove_cv_t;
|
||||
using std::remove_reference_t;
|
||||
#else
|
||||
template <bool B, typename T = void> using enable_if_t = typename std::enable_if<B, T>::type;
|
||||
template <bool B, typename T, typename F> using conditional_t = typename std::conditional<B, T, F>::type;
|
||||
template <typename T> using remove_cv_t = typename std::remove_cv<T>::type;
|
||||
template <typename T> using remove_reference_t = typename std::remove_reference<T>::type;
|
||||
#endif
|
||||
|
||||
/// Index sequences
|
||||
#if defined(PYBIND11_CPP14)
|
||||
using std::index_sequence;
|
||||
using std::make_index_sequence;
|
||||
#else
|
||||
template<size_t ...> struct index_sequence { };
|
||||
template<size_t N, size_t ...S> struct make_index_sequence_impl : make_index_sequence_impl <N - 1, N - 1, S...> { };
|
||||
template<size_t ...S> struct make_index_sequence_impl <0, S...> { typedef index_sequence<S...> type; };
|
||||
template<size_t N> using make_index_sequence = typename make_index_sequence_impl<N>::type;
|
||||
#endif
|
||||
|
||||
/// Make an index sequence of the indices of true arguments
|
||||
template <typename ISeq, size_t, bool...> struct select_indices_impl { using type = ISeq; };
|
||||
template <size_t... IPrev, size_t I, bool B, bool... Bs> struct select_indices_impl<index_sequence<IPrev...>, I, B, Bs...>
|
||||
: select_indices_impl<conditional_t<B, index_sequence<IPrev..., I>, index_sequence<IPrev...>>, I + 1, Bs...> {};
|
||||
template <bool... Bs> using select_indices = typename select_indices_impl<index_sequence<>, 0, Bs...>::type;
|
||||
|
||||
/// Backports of std::bool_constant and std::negation to accommodate older compilers
|
||||
template <bool B> using bool_constant = std::integral_constant<bool, B>;
|
||||
template <typename T> struct negation : bool_constant<!T::value> { };
|
||||
|
||||
template <typename...> struct void_t_impl { using type = void; };
|
||||
template <typename... Ts> using void_t = typename void_t_impl<Ts...>::type;
|
||||
|
||||
/// Compile-time all/any/none of that check the boolean value of all template types
|
||||
#if defined(__cpp_fold_expressions) && !(defined(_MSC_VER) && (_MSC_VER < 1916))
|
||||
template <class... Ts> using all_of = bool_constant<(Ts::value && ...)>;
|
||||
template <class... Ts> using any_of = bool_constant<(Ts::value || ...)>;
|
||||
#elif !defined(_MSC_VER)
|
||||
template <bool...> struct bools {};
|
||||
template <class... Ts> using all_of = std::is_same<
|
||||
bools<Ts::value..., true>,
|
||||
bools<true, Ts::value...>>;
|
||||
template <class... Ts> using any_of = negation<all_of<negation<Ts>...>>;
|
||||
#else
|
||||
// MSVC has trouble with the above, but supports std::conjunction, which we can use instead (albeit
|
||||
// at a slight loss of compilation efficiency).
|
||||
template <class... Ts> using all_of = std::conjunction<Ts...>;
|
||||
template <class... Ts> using any_of = std::disjunction<Ts...>;
|
||||
#endif
|
||||
template <class... Ts> using none_of = negation<any_of<Ts...>>;
|
||||
|
||||
template <class T, template<class> class... Predicates> using satisfies_all_of = all_of<Predicates<T>...>;
|
||||
template <class T, template<class> class... Predicates> using satisfies_any_of = any_of<Predicates<T>...>;
|
||||
template <class T, template<class> class... Predicates> using satisfies_none_of = none_of<Predicates<T>...>;
|
||||
|
||||
/// Strip the class from a method type
|
||||
template <typename T> struct remove_class { };
|
||||
template <typename C, typename R, typename... A> struct remove_class<R (C::*)(A...)> { typedef R type(A...); };
|
||||
template <typename C, typename R, typename... A> struct remove_class<R (C::*)(A...) const> { typedef R type(A...); };
|
||||
|
||||
/// Helper template to strip away type modifiers
|
||||
template <typename T> struct intrinsic_type { typedef T type; };
|
||||
template <typename T> struct intrinsic_type<const T> { typedef typename intrinsic_type<T>::type type; };
|
||||
template <typename T> struct intrinsic_type<T*> { typedef typename intrinsic_type<T>::type type; };
|
||||
template <typename T> struct intrinsic_type<T&> { typedef typename intrinsic_type<T>::type type; };
|
||||
template <typename T> struct intrinsic_type<T&&> { typedef typename intrinsic_type<T>::type type; };
|
||||
template <typename T, size_t N> struct intrinsic_type<const T[N]> { typedef typename intrinsic_type<T>::type type; };
|
||||
template <typename T, size_t N> struct intrinsic_type<T[N]> { typedef typename intrinsic_type<T>::type type; };
|
||||
template <typename T> using intrinsic_t = typename intrinsic_type<T>::type;
|
||||
|
||||
/// Helper type to replace 'void' in some expressions
|
||||
struct void_type { };
|
||||
|
||||
/// Helper template which holds a list of types
|
||||
template <typename...> struct type_list { };
|
||||
|
||||
/// Compile-time integer sum
|
||||
#ifdef __cpp_fold_expressions
|
||||
template <typename... Ts> constexpr size_t constexpr_sum(Ts... ns) { return (0 + ... + size_t{ns}); }
|
||||
#else
|
||||
constexpr size_t constexpr_sum() { return 0; }
|
||||
template <typename T, typename... Ts>
|
||||
constexpr size_t constexpr_sum(T n, Ts... ns) { return size_t{n} + constexpr_sum(ns...); }
|
||||
#endif
|
||||
|
||||
NAMESPACE_BEGIN(constexpr_impl)
|
||||
/// Implementation details for constexpr functions
|
||||
constexpr int first(int i) { return i; }
|
||||
template <typename T, typename... Ts>
|
||||
constexpr int first(int i, T v, Ts... vs) { return v ? i : first(i + 1, vs...); }
|
||||
|
||||
constexpr int last(int /*i*/, int result) { return result; }
|
||||
template <typename T, typename... Ts>
|
||||
constexpr int last(int i, int result, T v, Ts... vs) { return last(i + 1, v ? i : result, vs...); }
|
||||
NAMESPACE_END(constexpr_impl)
|
||||
|
||||
/// Return the index of the first type in Ts which satisfies Predicate<T>. Returns sizeof...(Ts) if
|
||||
/// none match.
|
||||
template <template<typename> class Predicate, typename... Ts>
|
||||
constexpr int constexpr_first() { return constexpr_impl::first(0, Predicate<Ts>::value...); }
|
||||
|
||||
/// Return the index of the last type in Ts which satisfies Predicate<T>, or -1 if none match.
|
||||
template <template<typename> class Predicate, typename... Ts>
|
||||
constexpr int constexpr_last() { return constexpr_impl::last(0, -1, Predicate<Ts>::value...); }
|
||||
|
||||
/// Return the Nth element from the parameter pack
|
||||
template <size_t N, typename T, typename... Ts>
|
||||
struct pack_element { using type = typename pack_element<N - 1, Ts...>::type; };
|
||||
template <typename T, typename... Ts>
|
||||
struct pack_element<0, T, Ts...> { using type = T; };
|
||||
|
||||
/// Return the one and only type which matches the predicate, or Default if none match.
|
||||
/// If more than one type matches the predicate, fail at compile-time.
|
||||
template <template<typename> class Predicate, typename Default, typename... Ts>
|
||||
struct exactly_one {
|
||||
static constexpr auto found = constexpr_sum(Predicate<Ts>::value...);
|
||||
static_assert(found <= 1, "Found more than one type matching the predicate");
|
||||
|
||||
static constexpr auto index = found ? constexpr_first<Predicate, Ts...>() : 0;
|
||||
using type = conditional_t<found, typename pack_element<index, Ts...>::type, Default>;
|
||||
};
|
||||
template <template<typename> class P, typename Default>
|
||||
struct exactly_one<P, Default> { using type = Default; };
|
||||
|
||||
template <template<typename> class Predicate, typename Default, typename... Ts>
|
||||
using exactly_one_t = typename exactly_one<Predicate, Default, Ts...>::type;
|
||||
|
||||
/// Defer the evaluation of type T until types Us are instantiated
|
||||
template <typename T, typename... /*Us*/> struct deferred_type { using type = T; };
|
||||
template <typename T, typename... Us> using deferred_t = typename deferred_type<T, Us...>::type;
|
||||
|
||||
/// Like is_base_of, but requires a strict base (i.e. `is_strict_base_of<T, T>::value == false`,
|
||||
/// unlike `std::is_base_of`)
|
||||
template <typename Base, typename Derived> using is_strict_base_of = bool_constant<
|
||||
std::is_base_of<Base, Derived>::value && !std::is_same<Base, Derived>::value>;
|
||||
|
||||
/// Like is_base_of, but also requires that the base type is accessible (i.e. that a Derived pointer
|
||||
/// can be converted to a Base pointer)
|
||||
template <typename Base, typename Derived> using is_accessible_base_of = bool_constant<
|
||||
std::is_base_of<Base, Derived>::value && std::is_convertible<Derived *, Base *>::value>;
|
||||
|
||||
template <template<typename...> class Base>
|
||||
struct is_template_base_of_impl {
|
||||
template <typename... Us> static std::true_type check(Base<Us...> *);
|
||||
static std::false_type check(...);
|
||||
};
|
||||
|
||||
/// Check if a template is the base of a type. For example:
|
||||
/// `is_template_base_of<Base, T>` is true if `struct T : Base<U> {}` where U can be anything
|
||||
template <template<typename...> class Base, typename T>
|
||||
#if !defined(_MSC_VER)
|
||||
using is_template_base_of = decltype(is_template_base_of_impl<Base>::check((intrinsic_t<T>*)nullptr));
|
||||
#else // MSVC2015 has trouble with decltype in template aliases
|
||||
struct is_template_base_of : decltype(is_template_base_of_impl<Base>::check((intrinsic_t<T>*)nullptr)) { };
|
||||
#endif
|
||||
|
||||
/// Check if T is an instantiation of the template `Class`. For example:
|
||||
/// `is_instantiation<shared_ptr, T>` is true if `T == shared_ptr<U>` where U can be anything.
|
||||
template <template<typename...> class Class, typename T>
|
||||
struct is_instantiation : std::false_type { };
|
||||
template <template<typename...> class Class, typename... Us>
|
||||
struct is_instantiation<Class, Class<Us...>> : std::true_type { };
|
||||
|
||||
/// Check if T is std::shared_ptr<U> where U can be anything
|
||||
template <typename T> using is_shared_ptr = is_instantiation<std::shared_ptr, T>;
|
||||
|
||||
/// Check if T looks like an input iterator
|
||||
template <typename T, typename = void> struct is_input_iterator : std::false_type {};
|
||||
template <typename T>
|
||||
struct is_input_iterator<T, void_t<decltype(*std::declval<T &>()), decltype(++std::declval<T &>())>>
|
||||
: std::true_type {};
|
||||
|
||||
template <typename T> using is_function_pointer = bool_constant<
|
||||
std::is_pointer<T>::value && std::is_function<typename std::remove_pointer<T>::type>::value>;
|
||||
|
||||
template <typename F> struct strip_function_object {
|
||||
using type = typename remove_class<decltype(&F::operator())>::type;
|
||||
};
|
||||
|
||||
// Extracts the function signature from a function, function pointer or lambda.
|
||||
template <typename Function, typename F = remove_reference_t<Function>>
|
||||
using function_signature_t = conditional_t<
|
||||
std::is_function<F>::value,
|
||||
F,
|
||||
typename conditional_t<
|
||||
std::is_pointer<F>::value || std::is_member_pointer<F>::value,
|
||||
std::remove_pointer<F>,
|
||||
strip_function_object<F>
|
||||
>::type
|
||||
>;
|
||||
|
||||
/// Returns true if the type looks like a lambda: that is, isn't a function, pointer or member
|
||||
/// pointer. Note that this can catch all sorts of other things, too; this is intended to be used
|
||||
/// in a place where passing a lambda makes sense.
|
||||
template <typename T> using is_lambda = satisfies_none_of<remove_reference_t<T>,
|
||||
std::is_function, std::is_pointer, std::is_member_pointer>;
|
||||
|
||||
/// Ignore that a variable is unused in compiler warnings
|
||||
inline void ignore_unused(const int *) { }
|
||||
|
||||
/// Apply a function over each element of a parameter pack
|
||||
#ifdef __cpp_fold_expressions
|
||||
#define PYBIND11_EXPAND_SIDE_EFFECTS(PATTERN) (((PATTERN), void()), ...)
|
||||
#else
|
||||
using expand_side_effects = bool[];
|
||||
#define PYBIND11_EXPAND_SIDE_EFFECTS(PATTERN) pybind11::detail::expand_side_effects{ ((PATTERN), void(), false)..., false }
|
||||
#endif
|
||||
|
||||
NAMESPACE_END(detail)
|
||||
|
||||
/// C++ bindings of builtin Python exceptions
|
||||
class builtin_exception : public std::runtime_error {
|
||||
public:
|
||||
using std::runtime_error::runtime_error;
|
||||
/// Set the error using the Python C API
|
||||
virtual void set_error() const = 0;
|
||||
};
|
||||
|
||||
#define PYBIND11_RUNTIME_EXCEPTION(name, type) \
|
||||
class name : public builtin_exception { public: \
|
||||
using builtin_exception::builtin_exception; \
|
||||
name() : name("") { } \
|
||||
void set_error() const override { PyErr_SetString(type, what()); } \
|
||||
};
|
||||
|
||||
PYBIND11_RUNTIME_EXCEPTION(stop_iteration, PyExc_StopIteration)
|
||||
PYBIND11_RUNTIME_EXCEPTION(index_error, PyExc_IndexError)
|
||||
PYBIND11_RUNTIME_EXCEPTION(key_error, PyExc_KeyError)
|
||||
PYBIND11_RUNTIME_EXCEPTION(value_error, PyExc_ValueError)
|
||||
PYBIND11_RUNTIME_EXCEPTION(type_error, PyExc_TypeError)
|
||||
PYBIND11_RUNTIME_EXCEPTION(cast_error, PyExc_RuntimeError) /// Thrown when pybind11::cast or handle::call fail due to a type casting error
|
||||
PYBIND11_RUNTIME_EXCEPTION(reference_cast_error, PyExc_RuntimeError) /// Used internally
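These wrappers let bound C++ code raise the corresponding builtin Python exception by throwing. An illustrative sketch (not part of this diff; the `at` function and module name `example` are hypothetical):

```cpp
// Hypothetical example: throwing a pybind11 exception wrapper from bound code.
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <vector>
namespace py = pybind11;

PYBIND11_MODULE(example, m) {
    m.def("at", [](const std::vector<int> &v, size_t i) {
        if (i >= v.size())
            throw py::index_error("index out of range");  // raised as IndexError in Python
        return v[i];
    });
}
```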
|
||||
|
||||
[[noreturn]] PYBIND11_NOINLINE inline void pybind11_fail(const char *reason) { throw std::runtime_error(reason); }
|
||||
[[noreturn]] PYBIND11_NOINLINE inline void pybind11_fail(const std::string &reason) { throw std::runtime_error(reason); }
|
||||
|
||||
template <typename T, typename SFINAE = void> struct format_descriptor { };
|
||||
|
||||
NAMESPACE_BEGIN(detail)
|
||||
// Returns the index of the given type in the type char array below, and in the list in numpy.h
|
||||
// The order here is: bool; 8 ints ((signed,unsigned)x(8,16,32,64)bits); float,double,long double;
|
||||
// complex float,double,long double. Note that the long double types only participate when long
|
||||
// double is actually longer than double (it isn't under MSVC).
|
||||
// NB: not only the string below but also complex.h and numpy.h rely on this order.
|
||||
template <typename T, typename SFINAE = void> struct is_fmt_numeric { static constexpr bool value = false; };
|
||||
template <typename T> struct is_fmt_numeric<T, enable_if_t<std::is_arithmetic<T>::value>> {
|
||||
static constexpr bool value = true;
|
||||
static constexpr int index = std::is_same<T, bool>::value ? 0 : 1 + (
|
||||
std::is_integral<T>::value ? detail::log2(sizeof(T))*2 + std::is_unsigned<T>::value : 8 + (
|
||||
std::is_same<T, double>::value ? 1 : std::is_same<T, long double>::value ? 2 : 0));
|
||||
};
|
||||
NAMESPACE_END(detail)
|
||||
|
||||
template <typename T> struct format_descriptor<T, detail::enable_if_t<std::is_arithmetic<T>::value>> {
|
||||
static constexpr const char c = "?bBhHiIqQfdg"[detail::is_fmt_numeric<T>::index];
|
||||
static constexpr const char value[2] = { c, '\0' };
|
||||
static std::string format() { return std::string(1, c); }
|
||||
};
|
||||
|
||||
#if !defined(PYBIND11_CPP17)
|
||||
|
||||
template <typename T> constexpr const char format_descriptor<
|
||||
T, detail::enable_if_t<std::is_arithmetic<T>::value>>::value[2];
|
||||
|
||||
#endif
|
||||
|
||||
/// RAII wrapper that temporarily clears any Python error state
|
||||
struct error_scope {
|
||||
PyObject *type, *value, *trace;
|
||||
error_scope() { PyErr_Fetch(&type, &value, &trace); }
|
||||
~error_scope() { PyErr_Restore(type, value, trace); }
|
||||
};
|
||||
|
||||
/// Dummy destructor wrapper that can be used to expose classes with a private destructor
|
||||
struct nodelete { template <typename T> void operator()(T*) { } };
|
||||
|
||||
// overload_cast requires variable templates: C++14
|
||||
#if defined(PYBIND11_CPP14)
|
||||
#define PYBIND11_OVERLOAD_CAST 1
|
||||
|
||||
NAMESPACE_BEGIN(detail)
|
||||
template <typename... Args>
|
||||
struct overload_cast_impl {
|
||||
constexpr overload_cast_impl() {} // MSVC 2015 needs this
|
||||
|
||||
template <typename Return>
|
||||
constexpr auto operator()(Return (*pf)(Args...)) const noexcept
|
||||
-> decltype(pf) { return pf; }
|
||||
|
||||
template <typename Return, typename Class>
|
||||
constexpr auto operator()(Return (Class::*pmf)(Args...), std::false_type = {}) const noexcept
|
||||
-> decltype(pmf) { return pmf; }
|
||||
|
||||
template <typename Return, typename Class>
|
||||
constexpr auto operator()(Return (Class::*pmf)(Args...) const, std::true_type) const noexcept
|
||||
-> decltype(pmf) { return pmf; }
|
||||
};
|
||||
NAMESPACE_END(detail)
|
||||
|
||||
/// Syntax sugar for resolving overloaded function pointers:
|
||||
/// - regular: static_cast<Return (Class::*)(Arg0, Arg1, Arg2)>(&Class::func)
|
||||
/// - sweet: overload_cast<Arg0, Arg1, Arg2>(&Class::func)
|
||||
template <typename... Args>
|
||||
static constexpr detail::overload_cast_impl<Args...> overload_cast = {};
|
||||
// MSVC 2015 only accepts this particular initialization syntax for this variable template.
|
||||
|
||||
/// Const member function selector for overload_cast
|
||||
/// - regular: static_cast<Return (Class::*)(Arg) const>(&Class::func)
|
||||
/// - sweet: overload_cast<Arg>(&Class::func, const_)
|
||||
static constexpr auto const_ = std::true_type{};
|
||||
|
||||
#else // no overload_cast: providing something that static_assert-fails:
|
||||
template <typename... Args> struct overload_cast {
|
||||
static_assert(detail::deferred_t<std::false_type, Args...>::value,
|
||||
"pybind11::overload_cast<...> requires compiling in C++14 mode");
|
||||
};
|
||||
#endif // overload_cast
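An illustrative sketch of the overload_cast / const_ helpers in use (not part of this diff; `Widget` and the module name `example` are hypothetical; requires C++14):

```cpp
// Hypothetical example: disambiguating overloads with overload_cast / const_.
#include <pybind11/pybind11.h>
#include <string>
namespace py = pybind11;

struct Widget {
    void load(int id) { (void) id; }
    void load(const std::string &name) { (void) name; }
    int size() const { return 0; }
};

PYBIND11_MODULE(example, m) {
    py::class_<Widget>(m, "Widget")
        .def(py::init<>())
        .def("load", py::overload_cast<int>(&Widget::load))
        .def("load", py::overload_cast<const std::string &>(&Widget::load))
        // const_ selects the const-qualified overload.
        .def("size", py::overload_cast<>(&Widget::size, py::const_));
}
```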
|
||||
|
||||
NAMESPACE_BEGIN(detail)
|
||||
|
||||
// Adaptor for converting arbitrary container arguments into a vector; implicitly convertible from
|
||||
// any standard container (or C-style array) supporting std::begin/std::end, any singleton
|
||||
// arithmetic type (if T is arithmetic), or explicitly constructible from an iterator pair.
|
||||
template <typename T>
|
||||
class any_container {
|
||||
std::vector<T> v;
|
||||
public:
|
||||
any_container() = default;
|
||||
|
||||
// Can construct from a pair of iterators
|
||||
template <typename It, typename = enable_if_t<is_input_iterator<It>::value>>
|
||||
any_container(It first, It last) : v(first, last) { }
|
||||
|
||||
// Implicit conversion constructor from any arbitrary container type with values convertible to T
|
||||
template <typename Container, typename = enable_if_t<std::is_convertible<decltype(*std::begin(std::declval<const Container &>())), T>::value>>
|
||||
any_container(const Container &c) : any_container(std::begin(c), std::end(c)) { }
|
||||
|
||||
// initializer_lists aren't deducible, so don't get matched by the above template; we need this
// to explicitly allow implicit conversion from one:
|
||||
template <typename TIn, typename = enable_if_t<std::is_convertible<TIn, T>::value>>
|
||||
any_container(const std::initializer_list<TIn> &c) : any_container(c.begin(), c.end()) { }
|
||||
|
||||
// Avoid copying if given an rvalue vector of the correct type.
|
||||
any_container(std::vector<T> &&v) : v(std::move(v)) { }
|
||||
|
||||
// Moves the vector out of an rvalue any_container
|
||||
operator std::vector<T> &&() && { return std::move(v); }
|
||||
|
||||
// Dereferencing obtains a reference to the underlying vector
|
||||
std::vector<T> &operator*() { return v; }
|
||||
const std::vector<T> &operator*() const { return v; }
|
||||
|
||||
// -> lets you call methods on the underlying vector
|
||||
std::vector<T> *operator->() { return &v; }
|
||||
const std::vector<T> *operator->() const { return &v; }
|
||||
};
|
||||
|
||||
NAMESPACE_END(detail)
|
||||
|
||||
|
||||
|
||||
NAMESPACE_END(PYBIND11_NAMESPACE)
@@ -1,100 +0,0 @@
/*
|
||||
pybind11/detail/descr.h: Helper type for concatenating type signatures at compile time
|
||||
|
||||
Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>
|
||||
|
||||
All rights reserved. Use of this source code is governed by a
|
||||
BSD-style license that can be found in the LICENSE file.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "common.h"
|
||||
|
||||
NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
|
||||
NAMESPACE_BEGIN(detail)
|
||||
|
||||
#if !defined(_MSC_VER)
|
||||
# define PYBIND11_DESCR_CONSTEXPR static constexpr
|
||||
#else
|
||||
# define PYBIND11_DESCR_CONSTEXPR const
|
||||
#endif
|
||||
|
||||
/* Concatenate type signatures at compile time */
|
||||
template <size_t N, typename... Ts>
|
||||
struct descr {
|
||||
char text[N + 1];
|
||||
|
||||
constexpr descr() : text{'\0'} { }
|
||||
constexpr descr(char const (&s)[N+1]) : descr(s, make_index_sequence<N>()) { }
|
||||
|
||||
template <size_t... Is>
|
||||
constexpr descr(char const (&s)[N+1], index_sequence<Is...>) : text{s[Is]..., '\0'} { }
|
||||
|
||||
template <typename... Chars>
|
||||
constexpr descr(char c, Chars... cs) : text{c, static_cast<char>(cs)..., '\0'} { }
|
||||
|
||||
static constexpr std::array<const std::type_info *, sizeof...(Ts) + 1> types() {
|
||||
return {{&typeid(Ts)..., nullptr}};
|
||||
}
|
||||
};
|
||||
|
||||
template <size_t N1, size_t N2, typename... Ts1, typename... Ts2, size_t... Is1, size_t... Is2>
|
||||
constexpr descr<N1 + N2, Ts1..., Ts2...> plus_impl(const descr<N1, Ts1...> &a, const descr<N2, Ts2...> &b,
|
||||
index_sequence<Is1...>, index_sequence<Is2...>) {
|
||||
return {a.text[Is1]..., b.text[Is2]...};
|
||||
}
|
||||
|
||||
template <size_t N1, size_t N2, typename... Ts1, typename... Ts2>
|
||||
constexpr descr<N1 + N2, Ts1..., Ts2...> operator+(const descr<N1, Ts1...> &a, const descr<N2, Ts2...> &b) {
|
||||
return plus_impl(a, b, make_index_sequence<N1>(), make_index_sequence<N2>());
|
||||
}
|
||||
|
||||
template <size_t N>
|
||||
constexpr descr<N - 1> _(char const(&text)[N]) { return descr<N - 1>(text); }
|
||||
constexpr descr<0> _(char const(&)[1]) { return {}; }
|
||||
|
||||
template <size_t Rem, size_t... Digits> struct int_to_str : int_to_str<Rem/10, Rem%10, Digits...> { };
|
||||
template <size_t...Digits> struct int_to_str<0, Digits...> {
|
||||
static constexpr auto digits = descr<sizeof...(Digits)>(('0' + Digits)...);
|
||||
};
|
||||
|
||||
// Ternary description (like std::conditional)
|
||||
template <bool B, size_t N1, size_t N2>
|
||||
constexpr enable_if_t<B, descr<N1 - 1>> _(char const(&text1)[N1], char const(&)[N2]) {
|
||||
return _(text1);
|
||||
}
|
||||
template <bool B, size_t N1, size_t N2>
|
||||
constexpr enable_if_t<!B, descr<N2 - 1>> _(char const(&)[N1], char const(&text2)[N2]) {
|
||||
return _(text2);
|
||||
}
|
||||
|
||||
template <bool B, typename T1, typename T2>
|
||||
constexpr enable_if_t<B, T1> _(const T1 &d, const T2 &) { return d; }
|
||||
template <bool B, typename T1, typename T2>
|
||||
constexpr enable_if_t<!B, T2> _(const T1 &, const T2 &d) { return d; }
|
||||
|
||||
template <size_t Size> auto constexpr _() -> decltype(int_to_str<Size / 10, Size % 10>::digits) {
|
||||
return int_to_str<Size / 10, Size % 10>::digits;
|
||||
}
|
||||
|
||||
template <typename Type> constexpr descr<1, Type> _() { return {'%'}; }
|
||||
|
||||
constexpr descr<0> concat() { return {}; }
|
||||
|
||||
template <size_t N, typename... Ts>
|
||||
constexpr descr<N, Ts...> concat(const descr<N, Ts...> &descr) { return descr; }
|
||||
|
||||
template <size_t N, typename... Ts, typename... Args>
|
||||
constexpr auto concat(const descr<N, Ts...> &d, const Args &...args)
|
||||
-> decltype(std::declval<descr<N + 2, Ts...>>() + concat(args...)) {
|
||||
return d + _(", ") + concat(args...);
|
||||
}
|
||||
|
||||
template <size_t N, typename... Ts>
|
||||
constexpr descr<N + 2, Ts...> type_descr(const descr<N, Ts...> &descr) {
|
||||
return _("{") + descr + _("}");
|
||||
}
|
||||
|
||||
NAMESPACE_END(detail)
|
||||
NAMESPACE_END(PYBIND11_NAMESPACE)
@@ -1,335 +0,0 @@
/*
|
||||
pybind11/detail/init.h: init factory function implementation and support code.
|
||||
|
||||
Copyright (c) 2017 Jason Rhinelander <jason@imaginary.ca>
|
||||
|
||||
All rights reserved. Use of this source code is governed by a
|
||||
BSD-style license that can be found in the LICENSE file.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "class.h"
|
||||
|
||||
NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
|
||||
NAMESPACE_BEGIN(detail)
|
||||
|
||||
template <>
|
||||
class type_caster<value_and_holder> {
|
||||
public:
|
||||
bool load(handle h, bool) {
|
||||
value = reinterpret_cast<value_and_holder *>(h.ptr());
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename> using cast_op_type = value_and_holder &;
|
||||
operator value_and_holder &() { return *value; }
|
||||
static constexpr auto name = _<value_and_holder>();
|
||||
|
||||
private:
|
||||
value_and_holder *value = nullptr;
|
||||
};
|
||||
|
||||
NAMESPACE_BEGIN(initimpl)
|
||||
|
||||
inline void no_nullptr(void *ptr) {
|
||||
if (!ptr) throw type_error("pybind11::init(): factory function returned nullptr");
|
||||
}
|
||||
|
||||
// Implementing functions for all forms of py::init<...> and py::init(...)
|
||||
template <typename Class> using Cpp = typename Class::type;
|
||||
template <typename Class> using Alias = typename Class::type_alias;
|
||||
template <typename Class> using Holder = typename Class::holder_type;
|
||||
|
||||
template <typename Class> using is_alias_constructible = std::is_constructible<Alias<Class>, Cpp<Class> &&>;
|
||||
|
||||
// Takes a Cpp pointer and returns true if it actually is a polymorphic Alias instance.
|
||||
template <typename Class, enable_if_t<Class::has_alias, int> = 0>
|
||||
bool is_alias(Cpp<Class> *ptr) {
|
||||
return dynamic_cast<Alias<Class> *>(ptr) != nullptr;
|
||||
}
|
||||
// Failing fallback version of the above for a no-alias class (always returns false)
|
||||
template <typename /*Class*/>
|
||||
constexpr bool is_alias(void *) { return false; }
|
||||
|
||||
// Constructs and returns a new object; if the given arguments don't map to a constructor, we fall
|
||||
// back to brace aggregate initialization so that aggregate initialization can be used with
// py::init, e.g. `py::init<int, int>` to initialize a `struct T { int a; int b; }`. For
|
||||
// non-aggregate types, we need to use an ordinary T(...) constructor (invoking as `T{...}` usually
|
||||
// works, but will not do the expected thing when `T` has an `initializer_list<T>` constructor).
|
||||
template <typename Class, typename... Args, detail::enable_if_t<std::is_constructible<Class, Args...>::value, int> = 0>
|
||||
inline Class *construct_or_initialize(Args &&...args) { return new Class(std::forward<Args>(args)...); }
|
||||
template <typename Class, typename... Args, detail::enable_if_t<!std::is_constructible<Class, Args...>::value, int> = 0>
|
||||
inline Class *construct_or_initialize(Args &&...args) { return new Class{std::forward<Args>(args)...}; }
|
||||
|
||||
// Attempts to construct an alias using an `Alias(Cpp &&)` constructor. This allows types with
// an alias to provide only a single Cpp factory function as long as the Alias can be
|
||||
// constructed from an rvalue reference of the base Cpp type. This means that Alias classes
|
||||
// can, when appropriate, simply define an `Alias(Cpp &&)` constructor rather than needing to
// inherit all the base class constructors.
|
||||
template <typename Class>
|
||||
void construct_alias_from_cpp(std::true_type /*is_alias_constructible*/,
|
||||
value_and_holder &v_h, Cpp<Class> &&base) {
|
||||
v_h.value_ptr() = new Alias<Class>(std::move(base));
|
||||
}
|
||||
template <typename Class>
|
||||
[[noreturn]] void construct_alias_from_cpp(std::false_type /*!is_alias_constructible*/,
|
||||
value_and_holder &, Cpp<Class> &&) {
|
||||
throw type_error("pybind11::init(): unable to convert returned instance to required "
|
||||
"alias class: no `Alias<Class>(Class &&)` constructor available");
|
||||
}
|
||||
|
||||
// Error-generating fallback for factories that don't match one of the below construction
|
||||
// mechanisms.
|
||||
template <typename Class>
|
||||
void construct(...) {
|
||||
static_assert(!std::is_same<Class, Class>::value /* always false */,
|
||||
"pybind11::init(): init function must return a compatible pointer, "
|
||||
"holder, or value");
|
||||
}
|
||||
|
||||
// Pointer return v1: the factory function returns a class pointer for a registered class.
|
||||
// If we don't need an alias (because this class doesn't have one, or because the final type is
|
||||
// inherited on the Python side) we can simply take over ownership. Otherwise we need to try to
|
||||
// construct an Alias from the returned base instance.
|
||||
template <typename Class>
|
||||
void construct(value_and_holder &v_h, Cpp<Class> *ptr, bool need_alias) {
|
||||
no_nullptr(ptr);
|
||||
if (Class::has_alias && need_alias && !is_alias<Class>(ptr)) {
|
||||
// We're going to try to construct an alias by moving the cpp type. Whether or not
|
||||
// that succeeds, we still need to destroy the original cpp pointer (either the
|
||||
// moved away leftover, if the alias construction works, or the value itself if we
|
||||
// throw an error), but we can't just call `delete ptr`: it might have a special
|
||||
// deleter, or might be shared_from_this. So we construct a holder around it as if
|
||||
// it was a normal instance, then steal the holder away into a local variable; thus
|
||||
// the holder destruction happens when we leave the C++ scope, and the holder
// class gets to handle the destruction however it likes.
|
||||
v_h.value_ptr() = ptr;
|
||||
v_h.set_instance_registered(true); // To prevent init_instance from registering it
|
||||
v_h.type->init_instance(v_h.inst, nullptr); // Set up the holder
|
||||
Holder<Class> temp_holder(std::move(v_h.holder<Holder<Class>>())); // Steal the holder
|
||||
v_h.type->dealloc(v_h); // Destroys the moved-out holder remains, resets value ptr to null
|
||||
v_h.set_instance_registered(false);
|
||||
|
||||
construct_alias_from_cpp<Class>(is_alias_constructible<Class>{}, v_h, std::move(*ptr));
|
||||
} else {
|
||||
// Otherwise the type isn't inherited, so we don't need an Alias
|
||||
v_h.value_ptr() = ptr;
|
||||
}
|
||||
}
|
||||
|
||||
// Pointer return v2: a factory that always returns an alias instance ptr. We simply take over
|
||||
// ownership of the pointer.
|
||||
template <typename Class, enable_if_t<Class::has_alias, int> = 0>
|
||||
void construct(value_and_holder &v_h, Alias<Class> *alias_ptr, bool) {
|
||||
no_nullptr(alias_ptr);
|
||||
v_h.value_ptr() = static_cast<Cpp<Class> *>(alias_ptr);
|
||||
}
|
||||
|
||||
// Holder return: copy its pointer, and move or copy the returned holder into the new instance's
|
||||
// holder. This also handles types like std::shared_ptr<T> and std::unique_ptr<T> where T is a
|
||||
// derived type (through those holder's implicit conversion from derived class holder constructors).
|
||||
template <typename Class>
|
||||
void construct(value_and_holder &v_h, Holder<Class> holder, bool need_alias) {
|
||||
auto *ptr = holder_helper<Holder<Class>>::get(holder);
|
||||
// If we need an alias, check that the held pointer is actually an alias instance
|
||||
if (Class::has_alias && need_alias && !is_alias<Class>(ptr))
|
||||
throw type_error("pybind11::init(): construction failed: returned holder-wrapped instance "
|
||||
"is not an alias instance");
|
||||
|
||||
v_h.value_ptr() = ptr;
|
||||
v_h.type->init_instance(v_h.inst, &holder);
|
||||
}
|
||||
|
||||
// return-by-value version 1: returning a cpp class by value. If the class has an alias and an
|
||||
// alias is required, the alias must have an `Alias(Cpp &&)` constructor so that we can construct
// the alias from the base when needed (i.e. because of Python-side inheritance). When we don't
|
||||
// need it, we simply move-construct the cpp value into a new instance.
|
||||
template <typename Class>
|
||||
void construct(value_and_holder &v_h, Cpp<Class> &&result, bool need_alias) {
|
||||
static_assert(std::is_move_constructible<Cpp<Class>>::value,
|
||||
"pybind11::init() return-by-value factory function requires a movable class");
|
||||
if (Class::has_alias && need_alias)
|
||||
construct_alias_from_cpp<Class>(is_alias_constructible<Class>{}, v_h, std::move(result));
|
||||
else
|
||||
v_h.value_ptr() = new Cpp<Class>(std::move(result));
|
||||
}
|
||||
|
||||
// return-by-value version 2: returning a value of the alias type itself. We move-construct an
|
||||
// Alias instance (even if no Python-side inheritance is involved). This is intended for
// cases where Alias initialization is always desired.
|
||||
template <typename Class>
|
||||
void construct(value_and_holder &v_h, Alias<Class> &&result, bool) {
|
||||
static_assert(std::is_move_constructible<Alias<Class>>::value,
|
||||
"pybind11::init() return-by-alias-value factory function requires a movable alias class");
|
||||
v_h.value_ptr() = new Alias<Class>(std::move(result));
|
||||
}
|
||||
|
||||
// Implementing class for py::init<...>()
|
||||
template <typename... Args>
|
||||
struct constructor {
|
||||
template <typename Class, typename... Extra, enable_if_t<!Class::has_alias, int> = 0>
|
||||
static void execute(Class &cl, const Extra&... extra) {
|
||||
cl.def("__init__", [](value_and_holder &v_h, Args... args) {
|
||||
v_h.value_ptr() = construct_or_initialize<Cpp<Class>>(std::forward<Args>(args)...);
|
||||
}, is_new_style_constructor(), extra...);
|
||||
}
|
||||
|
||||
template <typename Class, typename... Extra,
|
||||
enable_if_t<Class::has_alias &&
|
||||
std::is_constructible<Cpp<Class>, Args...>::value, int> = 0>
|
||||
static void execute(Class &cl, const Extra&... extra) {
|
||||
cl.def("__init__", [](value_and_holder &v_h, Args... args) {
|
||||
if (Py_TYPE(v_h.inst) == v_h.type->type)
|
||||
v_h.value_ptr() = construct_or_initialize<Cpp<Class>>(std::forward<Args>(args)...);
|
||||
else
|
||||
v_h.value_ptr() = construct_or_initialize<Alias<Class>>(std::forward<Args>(args)...);
|
||||
}, is_new_style_constructor(), extra...);
|
||||
}
|
||||
|
||||
template <typename Class, typename... Extra,
|
||||
enable_if_t<Class::has_alias &&
|
||||
!std::is_constructible<Cpp<Class>, Args...>::value, int> = 0>
|
||||
static void execute(Class &cl, const Extra&... extra) {
|
||||
cl.def("__init__", [](value_and_holder &v_h, Args... args) {
|
||||
v_h.value_ptr() = construct_or_initialize<Alias<Class>>(std::forward<Args>(args)...);
|
||||
}, is_new_style_constructor(), extra...);
|
||||
}
|
||||
};
|
||||
|
||||
// Implementing class for py::init_alias<...>()
|
||||
template <typename... Args> struct alias_constructor {
|
||||
template <typename Class, typename... Extra,
|
||||
enable_if_t<Class::has_alias && std::is_constructible<Alias<Class>, Args...>::value, int> = 0>
|
||||
static void execute(Class &cl, const Extra&... extra) {
|
||||
cl.def("__init__", [](value_and_holder &v_h, Args... args) {
|
||||
v_h.value_ptr() = construct_or_initialize<Alias<Class>>(std::forward<Args>(args)...);
|
||||
}, is_new_style_constructor(), extra...);
|
||||
}
|
||||
};
|
||||
|
||||
// Implementation class for py::init(Func) and py::init(Func, AliasFunc)
|
||||
template <typename CFunc, typename AFunc = void_type (*)(),
|
||||
typename = function_signature_t<CFunc>, typename = function_signature_t<AFunc>>
|
||||
struct factory;
|
||||
|
||||
// Specialization for py::init(Func)
|
||||
template <typename Func, typename Return, typename... Args>
|
||||
struct factory<Func, void_type (*)(), Return(Args...)> {
|
||||
remove_reference_t<Func> class_factory;
|
||||
|
||||
factory(Func &&f) : class_factory(std::forward<Func>(f)) { }
|
||||
|
||||
// The given class either has no alias or has no separate alias factory;
|
||||
// this always constructs the class itself. If the class is registered with an alias
|
||||
// type and an alias instance is needed (i.e. because the final type is a Python class
|
||||
// inheriting from the C++ type) the returned value needs to either already be an alias
|
||||
// instance, or the alias needs to be constructible from a `Class &&` argument.
|
||||
template <typename Class, typename... Extra>
|
||||
void execute(Class &cl, const Extra &...extra) && {
|
||||
#if defined(PYBIND11_CPP14)
|
||||
cl.def("__init__", [func = std::move(class_factory)]
|
||||
#else
|
||||
auto &func = class_factory;
|
||||
cl.def("__init__", [func]
|
||||
#endif
|
||||
(value_and_holder &v_h, Args... args) {
|
||||
construct<Class>(v_h, func(std::forward<Args>(args)...),
|
||||
Py_TYPE(v_h.inst) != v_h.type->type);
|
||||
}, is_new_style_constructor(), extra...);
|
||||
}
|
||||
};
|
||||
|
||||
// Specialization for py::init(Func, AliasFunc)
|
||||
template <typename CFunc, typename AFunc,
|
||||
typename CReturn, typename... CArgs, typename AReturn, typename... AArgs>
|
||||
struct factory<CFunc, AFunc, CReturn(CArgs...), AReturn(AArgs...)> {
|
||||
static_assert(sizeof...(CArgs) == sizeof...(AArgs),
|
||||
"pybind11::init(class_factory, alias_factory): class and alias factories "
|
||||
"must have identical argument signatures");
|
||||
static_assert(all_of<std::is_same<CArgs, AArgs>...>::value,
|
||||
"pybind11::init(class_factory, alias_factory): class and alias factories "
|
||||
"must have identical argument signatures");
|
||||
|
||||
remove_reference_t<CFunc> class_factory;
|
||||
remove_reference_t<AFunc> alias_factory;
|
||||
|
||||
factory(CFunc &&c, AFunc &&a)
|
||||
: class_factory(std::forward<CFunc>(c)), alias_factory(std::forward<AFunc>(a)) { }
|
||||
|
||||
// The class factory is called when the `self` type passed to `__init__` is the direct
|
||||
// class (i.e. not inherited), the alias factory when `self` is a Python-side subtype.
|
||||
template <typename Class, typename... Extra>
|
||||
void execute(Class &cl, const Extra&... extra) && {
|
||||
static_assert(Class::has_alias, "The two-argument version of `py::init()` can "
|
||||
"only be used if the class has an alias");
|
||||
#if defined(PYBIND11_CPP14)
|
||||
cl.def("__init__", [class_func = std::move(class_factory), alias_func = std::move(alias_factory)]
|
||||
#else
|
||||
auto &class_func = class_factory;
|
||||
auto &alias_func = alias_factory;
|
||||
cl.def("__init__", [class_func, alias_func]
|
||||
#endif
|
||||
(value_and_holder &v_h, CArgs... args) {
|
||||
if (Py_TYPE(v_h.inst) == v_h.type->type)
|
||||
// If the instance type equals the registered type we don't have inheritance, so
|
||||
// don't need the alias and can construct using the class function:
|
||||
construct<Class>(v_h, class_func(std::forward<CArgs>(args)...), false);
|
||||
else
|
||||
construct<Class>(v_h, alias_func(std::forward<CArgs>(args)...), true);
|
||||
}, is_new_style_constructor(), extra...);
|
||||
}
|
||||
};
|
||||
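// Usage sketch for the two-factory form handled above (hypothetical `Base`/`PyBase`
// trampoline types bound in a hypothetical module `m`): the first callable is used when
// `self` is exactly the bound type, the second when a Python-side subclass is constructed.
//
//     py::class_<Base, PyBase>(m, "Base")
//         .def(py::init([](int v) { return Base(v); },      // direct construction
//                       [](int v) { return PyBase(v); }));  // construction for Python subclasses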
|
||||
/// Set just the C++ state. Same as `__init__`.
|
||||
template <typename Class, typename T>
|
||||
void setstate(value_and_holder &v_h, T &&result, bool need_alias) {
|
||||
construct<Class>(v_h, std::forward<T>(result), need_alias);
|
||||
}
|
||||
|
||||
/// Set both the C++ and Python states
|
||||
template <typename Class, typename T, typename O,
|
||||
enable_if_t<std::is_convertible<O, handle>::value, int> = 0>
|
||||
void setstate(value_and_holder &v_h, std::pair<T, O> &&result, bool need_alias) {
|
||||
construct<Class>(v_h, std::move(result.first), need_alias);
|
||||
setattr((PyObject *) v_h.inst, "__dict__", result.second);
|
||||
}
|
||||
|
||||
/// Implementation for py::pickle(GetState, SetState)
|
||||
template <typename Get, typename Set,
|
||||
typename = function_signature_t<Get>, typename = function_signature_t<Set>>
|
||||
struct pickle_factory;
|
||||
|
||||
template <typename Get, typename Set,
|
||||
typename RetState, typename Self, typename NewInstance, typename ArgState>
|
||||
struct pickle_factory<Get, Set, RetState(Self), NewInstance(ArgState)> {
|
||||
static_assert(std::is_same<intrinsic_t<RetState>, intrinsic_t<ArgState>>::value,
|
||||
"The type returned by `__getstate__` must be the same "
|
||||
"as the argument accepted by `__setstate__`");
|
||||
|
||||
remove_reference_t<Get> get;
|
||||
remove_reference_t<Set> set;
|
||||
|
||||
pickle_factory(Get get, Set set)
|
||||
: get(std::forward<Get>(get)), set(std::forward<Set>(set)) { }
|
||||
|
||||
template <typename Class, typename... Extra>
|
||||
void execute(Class &cl, const Extra &...extra) && {
|
||||
cl.def("__getstate__", std::move(get));
|
||||
|
||||
#if defined(PYBIND11_CPP14)
|
||||
cl.def("__setstate__", [func = std::move(set)]
|
||||
#else
|
||||
auto &func = set;
|
||||
cl.def("__setstate__", [func]
|
||||
#endif
|
||||
(value_and_holder &v_h, ArgState state) {
|
||||
setstate<Class>(v_h, func(std::forward<ArgState>(state)),
|
||||
Py_TYPE(v_h.inst) != v_h.type->type);
|
||||
}, is_new_style_constructor(), extra...);
|
||||
}
|
||||
};
|
||||
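// Usage sketch for py::pickle (hypothetical `Pickleable` type): the getter produces the state
// object handed to Python's pickle machinery, and the setter rebuilds a C++ instance from that
// state via the pickle_factory above.
//
//     py::class_<Pickleable>(m, "Pickleable")
//         .def(py::init<std::string>())
//         .def(py::pickle(
//             [](const Pickleable &p) { return py::make_tuple(p.value()); },    // __getstate__
//             [](py::tuple t) { return Pickleable(t[0].cast<std::string>()); }  // __setstate__
//         ));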
|
||||
NAMESPACE_END(initimpl)
|
||||
NAMESPACE_END(detail)
|
||||
NAMESPACE_END(pybind11)
|
|
@ -1,291 +0,0 @@
|
|||
/*
|
||||
pybind11/detail/internals.h: Internal data structure and related functions
|
||||
|
||||
Copyright (c) 2017 Wenzel Jakob <wenzel.jakob@epfl.ch>
|
||||
|
||||
All rights reserved. Use of this source code is governed by a
|
||||
BSD-style license that can be found in the LICENSE file.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "../pytypes.h"
|
||||
|
||||
NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
|
||||
NAMESPACE_BEGIN(detail)
|
||||
// Forward declarations
|
||||
inline PyTypeObject *make_static_property_type();
|
||||
inline PyTypeObject *make_default_metaclass();
|
||||
inline PyObject *make_object_base_type(PyTypeObject *metaclass);
|
||||
|
||||
// The old Python Thread Local Storage (TLS) API is deprecated in Python 3.7 in favor of the new
|
||||
// Thread Specific Storage (TSS) API.
|
||||
#if PY_VERSION_HEX >= 0x03070000
|
||||
# define PYBIND11_TLS_KEY_INIT(var) Py_tss_t *var = nullptr
|
||||
# define PYBIND11_TLS_GET_VALUE(key) PyThread_tss_get((key))
|
||||
# define PYBIND11_TLS_REPLACE_VALUE(key, value) PyThread_tss_set((key), (value))
|
||||
# define PYBIND11_TLS_DELETE_VALUE(key) PyThread_tss_set((key), nullptr)
|
||||
#else
|
||||
// Usually an int but a long on Cygwin64 with Python 3.x
|
||||
# define PYBIND11_TLS_KEY_INIT(var) decltype(PyThread_create_key()) var = 0
|
||||
# define PYBIND11_TLS_GET_VALUE(key) PyThread_get_key_value((key))
|
||||
# if PY_MAJOR_VERSION < 3
|
||||
# define PYBIND11_TLS_DELETE_VALUE(key) \
|
||||
PyThread_delete_key_value(key)
|
||||
# define PYBIND11_TLS_REPLACE_VALUE(key, value) \
|
||||
do { \
|
||||
PyThread_delete_key_value((key)); \
|
||||
PyThread_set_key_value((key), (value)); \
|
||||
} while (false)
|
||||
# else
|
||||
# define PYBIND11_TLS_DELETE_VALUE(key) \
|
||||
PyThread_set_key_value((key), nullptr)
|
||||
# define PYBIND11_TLS_REPLACE_VALUE(key, value) \
|
||||
PyThread_set_key_value((key), (value))
|
||||
# endif
|
||||
#endif
|
||||
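// Usage sketch: pybind11 itself only ever stores the current PyThreadState through these
// macros; simplified (and slightly abridged) from gil_scoped_acquire, it looks roughly like:
//
//     auto &internals = detail::get_internals();
//     auto *tstate = (PyThreadState *) PYBIND11_TLS_GET_VALUE(internals.tstate);
//     if (!tstate) {
//         tstate = PyThreadState_New(internals.istate);
//         PYBIND11_TLS_REPLACE_VALUE(internals.tstate, tstate);
//     }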
|
||||
// Python loads modules by default with dlopen with the RTLD_LOCAL flag; under libc++ and possibly
|
||||
// other STLs, this means `typeid(A)` from one module won't equal `typeid(A)` from another module
|
||||
// even when `A` is the same, non-hidden-visibility type (e.g. from a common include). Under
|
||||
// libstdc++, this doesn't happen: equality and the type_index hash are based on the type name,
|
||||
// which works. If not under a known-good stl, provide our own name-based hash and equality
|
||||
// functions that use the type name.
|
||||
#if defined(__GLIBCXX__)
|
||||
inline bool same_type(const std::type_info &lhs, const std::type_info &rhs) { return lhs == rhs; }
|
||||
using type_hash = std::hash<std::type_index>;
|
||||
using type_equal_to = std::equal_to<std::type_index>;
|
||||
#else
|
||||
inline bool same_type(const std::type_info &lhs, const std::type_info &rhs) {
|
||||
return lhs.name() == rhs.name() || std::strcmp(lhs.name(), rhs.name()) == 0;
|
||||
}
|
||||
|
||||
struct type_hash {
|
||||
size_t operator()(const std::type_index &t) const {
|
||||
size_t hash = 5381;
|
||||
const char *ptr = t.name();
|
||||
while (auto c = static_cast<unsigned char>(*ptr++))
|
||||
hash = (hash * 33) ^ c;
|
||||
return hash;
|
||||
}
|
||||
};
|
||||
|
||||
struct type_equal_to {
|
||||
bool operator()(const std::type_index &lhs, const std::type_index &rhs) const {
|
||||
return lhs.name() == rhs.name() || std::strcmp(lhs.name(), rhs.name()) == 0;
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
template <typename value_type>
|
||||
using type_map = std::unordered_map<std::type_index, value_type, type_hash, type_equal_to>;
|
||||
|
||||
struct overload_hash {
|
||||
inline size_t operator()(const std::pair<const PyObject *, const char *>& v) const {
|
||||
size_t value = std::hash<const void *>()(v.first);
|
||||
value ^= std::hash<const void *>()(v.second) + 0x9e3779b9 + (value<<6) + (value>>2);
|
||||
return value;
|
||||
}
|
||||
};
|
||||
|
||||
/// Internal data structure used to track registered instances and types.
|
||||
/// Whenever binary incompatible changes are made to this structure,
|
||||
/// `PYBIND11_INTERNALS_VERSION` must be incremented.
|
||||
struct internals {
|
||||
type_map<type_info *> registered_types_cpp; // std::type_index -> pybind11's type information
|
||||
std::unordered_map<PyTypeObject *, std::vector<type_info *>> registered_types_py; // PyTypeObject* -> base type_info(s)
|
||||
std::unordered_multimap<const void *, instance*> registered_instances; // void * -> instance*
|
||||
std::unordered_set<std::pair<const PyObject *, const char *>, overload_hash> inactive_overload_cache;
|
||||
type_map<std::vector<bool (*)(PyObject *, void *&)>> direct_conversions;
|
||||
std::unordered_map<const PyObject *, std::vector<PyObject *>> patients;
|
||||
std::forward_list<void (*) (std::exception_ptr)> registered_exception_translators;
|
||||
std::unordered_map<std::string, void *> shared_data; // Custom data to be shared across extensions
|
||||
std::vector<PyObject *> loader_patient_stack; // Used by `loader_life_support`
|
||||
std::forward_list<std::string> static_strings; // Stores the std::strings backing detail::c_str()
|
||||
PyTypeObject *static_property_type;
|
||||
PyTypeObject *default_metaclass;
|
||||
PyObject *instance_base;
|
||||
#if defined(WITH_THREAD)
|
||||
PYBIND11_TLS_KEY_INIT(tstate);
|
||||
PyInterpreterState *istate = nullptr;
|
||||
#endif
|
||||
};
|
||||
|
||||
/// Additional type information which does not fit into the PyTypeObject.
|
||||
/// Changes to this struct also require bumping `PYBIND11_INTERNALS_VERSION`.
|
||||
struct type_info {
|
||||
PyTypeObject *type;
|
||||
const std::type_info *cpptype;
|
||||
size_t type_size, type_align, holder_size_in_ptrs;
|
||||
void *(*operator_new)(size_t);
|
||||
void (*init_instance)(instance *, const void *);
|
||||
void (*dealloc)(value_and_holder &v_h);
|
||||
std::vector<PyObject *(*)(PyObject *, PyTypeObject *)> implicit_conversions;
|
||||
std::vector<std::pair<const std::type_info *, void *(*)(void *)>> implicit_casts;
|
||||
std::vector<bool (*)(PyObject *, void *&)> *direct_conversions;
|
||||
buffer_info *(*get_buffer)(PyObject *, void *) = nullptr;
|
||||
void *get_buffer_data = nullptr;
|
||||
void *(*module_local_load)(PyObject *, const type_info *) = nullptr;
|
||||
/* A simple type never occurs as a (direct or indirect) parent
|
||||
* of a class that makes use of multiple inheritance */
|
||||
bool simple_type : 1;
|
||||
/* True if there is no multiple inheritance in this type's inheritance tree */
|
||||
bool simple_ancestors : 1;
|
||||
/* for base vs derived holder_type checks */
|
||||
bool default_holder : 1;
|
||||
/* true if this is a type registered with py::module_local */
|
||||
bool module_local : 1;
|
||||
};
|
||||
|
||||
/// Tracks the `internals` and `type_info` ABI version independent of the main library version
|
||||
#define PYBIND11_INTERNALS_VERSION 3
|
||||
|
||||
#if defined(_DEBUG)
|
||||
# define PYBIND11_BUILD_TYPE "_debug"
|
||||
#else
|
||||
# define PYBIND11_BUILD_TYPE ""
|
||||
#endif
|
||||
|
||||
#if defined(WITH_THREAD)
|
||||
# define PYBIND11_INTERNALS_KIND ""
|
||||
#else
|
||||
# define PYBIND11_INTERNALS_KIND "_without_thread"
|
||||
#endif
|
||||
|
||||
#define PYBIND11_INTERNALS_ID "__pybind11_internals_v" \
|
||||
PYBIND11_TOSTRING(PYBIND11_INTERNALS_VERSION) PYBIND11_INTERNALS_KIND PYBIND11_BUILD_TYPE "__"
|
||||
|
||||
#define PYBIND11_MODULE_LOCAL_ID "__pybind11_module_local_v" \
|
||||
PYBIND11_TOSTRING(PYBIND11_INTERNALS_VERSION) PYBIND11_INTERNALS_KIND PYBIND11_BUILD_TYPE "__"
|
||||
|
||||
/// Each module locally stores a pointer to the `internals` data. The data
|
||||
/// itself is shared among modules with the same `PYBIND11_INTERNALS_ID`.
|
||||
inline internals **&get_internals_pp() {
|
||||
static internals **internals_pp = nullptr;
|
||||
return internals_pp;
|
||||
}
|
||||
|
||||
/// Return a reference to the current `internals` data
|
||||
PYBIND11_NOINLINE inline internals &get_internals() {
|
||||
auto **&internals_pp = get_internals_pp();
|
||||
if (internals_pp && *internals_pp)
|
||||
return **internals_pp;
|
||||
|
||||
constexpr auto *id = PYBIND11_INTERNALS_ID;
|
||||
auto builtins = handle(PyEval_GetBuiltins());
|
||||
if (builtins.contains(id) && isinstance<capsule>(builtins[id])) {
|
||||
internals_pp = static_cast<internals **>(capsule(builtins[id]));
|
||||
|
||||
// We loaded the internals through python's builtins, which means that our `error_already_set`
|
||||
// and `builtin_exception` may be different local classes than the ones set up in the
|
||||
// initial exception translator, below, so add another for our local exception classes.
|
||||
//
|
||||
// libstdc++ doesn't require this (types there are identified only by name)
|
||||
#if !defined(__GLIBCXX__)
|
||||
(*internals_pp)->registered_exception_translators.push_front(
|
||||
[](std::exception_ptr p) -> void {
|
||||
try {
|
||||
if (p) std::rethrow_exception(p);
|
||||
} catch (error_already_set &e) { e.restore(); return;
|
||||
} catch (const builtin_exception &e) { e.set_error(); return;
|
||||
}
|
||||
}
|
||||
);
|
||||
#endif
|
||||
} else {
|
||||
if (!internals_pp) internals_pp = new internals*();
|
||||
auto *&internals_ptr = *internals_pp;
|
||||
internals_ptr = new internals();
|
||||
#if defined(WITH_THREAD)
|
||||
PyEval_InitThreads();
|
||||
PyThreadState *tstate = PyThreadState_Get();
|
||||
#if PY_VERSION_HEX >= 0x03070000
|
||||
internals_ptr->tstate = PyThread_tss_alloc();
|
||||
if (!internals_ptr->tstate || PyThread_tss_create(internals_ptr->tstate))
|
||||
pybind11_fail("get_internals: could not successfully initialize the TSS key!");
|
||||
PyThread_tss_set(internals_ptr->tstate, tstate);
|
||||
#else
|
||||
internals_ptr->tstate = PyThread_create_key();
|
||||
if (internals_ptr->tstate == -1)
|
||||
pybind11_fail("get_internals: could not successfully initialize the TLS key!");
|
||||
PyThread_set_key_value(internals_ptr->tstate, tstate);
|
||||
#endif
|
||||
internals_ptr->istate = tstate->interp;
|
||||
#endif
|
||||
builtins[id] = capsule(internals_pp);
|
||||
internals_ptr->registered_exception_translators.push_front(
|
||||
[](std::exception_ptr p) -> void {
|
||||
try {
|
||||
if (p) std::rethrow_exception(p);
|
||||
} catch (error_already_set &e) { e.restore(); return;
|
||||
} catch (const builtin_exception &e) { e.set_error(); return;
|
||||
} catch (const std::bad_alloc &e) { PyErr_SetString(PyExc_MemoryError, e.what()); return;
|
||||
} catch (const std::domain_error &e) { PyErr_SetString(PyExc_ValueError, e.what()); return;
|
||||
} catch (const std::invalid_argument &e) { PyErr_SetString(PyExc_ValueError, e.what()); return;
|
||||
} catch (const std::length_error &e) { PyErr_SetString(PyExc_ValueError, e.what()); return;
|
||||
} catch (const std::out_of_range &e) { PyErr_SetString(PyExc_IndexError, e.what()); return;
|
||||
} catch (const std::range_error &e) { PyErr_SetString(PyExc_ValueError, e.what()); return;
|
||||
} catch (const std::exception &e) { PyErr_SetString(PyExc_RuntimeError, e.what()); return;
|
||||
} catch (...) {
|
||||
PyErr_SetString(PyExc_RuntimeError, "Caught an unknown exception!");
|
||||
return;
|
||||
}
|
||||
}
|
||||
);
|
||||
internals_ptr->static_property_type = make_static_property_type();
|
||||
internals_ptr->default_metaclass = make_default_metaclass();
|
||||
internals_ptr->instance_base = make_object_base_type(internals_ptr->default_metaclass);
|
||||
}
|
||||
return **internals_pp;
|
||||
}
|
||||
|
||||
/// Works like `internals.registered_types_cpp`, but for module-local registered types:
|
||||
inline type_map<type_info *> ®istered_local_types_cpp() {
|
||||
static type_map<type_info *> locals{};
|
||||
return locals;
|
||||
}
|
||||
|
||||
/// Constructs a std::string with the given arguments, stores it in `internals`, and returns its
|
||||
/// `c_str()`. Such string objects have a long storage duration -- the internal strings are only
|
||||
/// cleared when the program exits or after interpreter shutdown (when embedding), and so are
|
||||
/// suitable for c-style strings needed by Python internals (such as PyTypeObject's tp_name).
|
||||
template <typename... Args>
|
||||
const char *c_str(Args &&...args) {
|
||||
auto &strings = get_internals().static_strings;
|
||||
strings.emplace_front(std::forward<Args>(args)...);
|
||||
return strings.front().c_str();
|
||||
}
|
||||
|
||||
NAMESPACE_END(detail)
|
||||
|
||||
/// Returns a named pointer that is shared among all extension modules (using the same
|
||||
/// pybind11 version) running in the current interpreter. Names starting with underscores
|
||||
/// are reserved for internal usage. Returns `nullptr` if no matching entry was found.
|
||||
inline PYBIND11_NOINLINE void *get_shared_data(const std::string &name) {
|
||||
auto &internals = detail::get_internals();
|
||||
auto it = internals.shared_data.find(name);
|
||||
return it != internals.shared_data.end() ? it->second : nullptr;
|
||||
}
|
||||
|
||||
/// Set the shared data that can be later recovered by `get_shared_data()`.
|
||||
inline PYBIND11_NOINLINE void *set_shared_data(const std::string &name, void *data) {
|
||||
detail::get_internals().shared_data[name] = data;
|
||||
return data;
|
||||
}
|
||||
|
||||
/// Returns a typed reference to a shared data entry (by using `get_shared_data()`) if
|
||||
/// such entry exists. Otherwise, a new object of default-constructible type `T` is
|
||||
/// added to the shared data under the given name and a reference to it is returned.
|
||||
template<typename T>
|
||||
T &get_or_create_shared_data(const std::string &name) {
|
||||
auto &internals = detail::get_internals();
|
||||
auto it = internals.shared_data.find(name);
|
||||
T *ptr = (T *) (it != internals.shared_data.end() ? it->second : nullptr);
|
||||
if (!ptr) {
|
||||
ptr = new T();
|
||||
internals.shared_data[name] = ptr;
|
||||
}
|
||||
return *ptr;
|
||||
}
|
||||
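// Usage sketch: the shared-data API above lets separately compiled extension modules share a
// single C++ object through the interpreter-wide `internals` table (the key name below is
// hypothetical):
//
//     // module A: create or look up a shared counter and bump it
//     auto &counter = py::get_or_create_shared_data<int>("my_project_counter");
//     ++counter;
//
//     // module B: read it back (nullptr if module A never ran)
//     void *p = py::get_shared_data("my_project_counter");
//     int seen = p ? *static_cast<int *>(p) : 0;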
|
||||
NAMESPACE_END(PYBIND11_NAMESPACE)
|
|
@ -1,53 +0,0 @@
|
|||
/*
|
||||
pybind11/detail/typeid.h: Compiler-independent access to type identifiers
|
||||
|
||||
Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>
|
||||
|
||||
All rights reserved. Use of this source code is governed by a
|
||||
BSD-style license that can be found in the LICENSE file.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstdio>
|
||||
#include <cstdlib>
|
||||
|
||||
#if defined(__GNUG__)
|
||||
#include <cxxabi.h>
|
||||
#endif
|
||||
|
||||
NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
|
||||
NAMESPACE_BEGIN(detail)
|
||||
/// Erase all occurrences of a substring
|
||||
inline void erase_all(std::string &string, const std::string &search) {
|
||||
for (size_t pos = 0;;) {
|
||||
pos = string.find(search, pos);
|
||||
if (pos == std::string::npos) break;
|
||||
string.erase(pos, search.length());
|
||||
}
|
||||
}
|
||||
|
||||
PYBIND11_NOINLINE inline void clean_type_id(std::string &name) {
|
||||
#if defined(__GNUG__)
|
||||
int status = 0;
|
||||
std::unique_ptr<char, void (*)(void *)> res {
|
||||
abi::__cxa_demangle(name.c_str(), nullptr, nullptr, &status), std::free };
|
||||
if (status == 0)
|
||||
name = res.get();
|
||||
#else
|
||||
detail::erase_all(name, "class ");
|
||||
detail::erase_all(name, "struct ");
|
||||
detail::erase_all(name, "enum ");
|
||||
#endif
|
||||
detail::erase_all(name, "pybind11::");
|
||||
}
|
||||
NAMESPACE_END(detail)
|
||||
|
||||
/// Return a string representation of a C++ type
|
||||
template <typename T> static std::string type_id() {
|
||||
std::string name(typeid(T).name());
|
||||
detail::clean_type_id(name);
|
||||
return name;
|
||||
}
|
||||
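// Usage sketch: type_id<T>() above is what pybind11 uses to print readable C++ type names in
// error messages; the exact spelling is compiler-dependent (demangled under GCC/Clang):
//
//     std::string a = pybind11::type_id<int>();                   // "int"
//     std::string b = pybind11::type_id<std::shared_ptr<int>>();  // e.g. "std::shared_ptr<int>"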
|
||||
NAMESPACE_END(PYBIND11_NAMESPACE)
|
|
@ -1,607 +0,0 @@
|
|||
/*
|
||||
pybind11/eigen.h: Transparent conversion for dense and sparse Eigen matrices
|
||||
|
||||
Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>
|
||||
|
||||
All rights reserved. Use of this source code is governed by a
|
||||
BSD-style license that can be found in the LICENSE file.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "numpy.h"
|
||||
|
||||
#if defined(__INTEL_COMPILER)
|
||||
# pragma warning(disable: 1682) // implicit conversion of a 64-bit integral type to a smaller integral type (potential portability problem)
|
||||
#elif defined(__GNUG__) || defined(__clang__)
|
||||
# pragma GCC diagnostic push
|
||||
# pragma GCC diagnostic ignored "-Wconversion"
|
||||
# pragma GCC diagnostic ignored "-Wdeprecated-declarations"
|
||||
# ifdef __clang__
|
||||
// Eigen generates a bunch of implicit-copy-constructor-is-deprecated warnings with -Wdeprecated
|
||||
// under Clang, so disable that warning here:
|
||||
# pragma GCC diagnostic ignored "-Wdeprecated"
|
||||
# endif
|
||||
# if __GNUC__ >= 7
|
||||
# pragma GCC diagnostic ignored "-Wint-in-bool-context"
|
||||
# endif
|
||||
#endif
|
||||
|
||||
#if defined(_MSC_VER)
|
||||
# pragma warning(push)
|
||||
# pragma warning(disable: 4127) // warning C4127: Conditional expression is constant
|
||||
# pragma warning(disable: 4996) // warning C4996: std::unary_negate is deprecated in C++17
|
||||
#endif
|
||||
|
||||
#include <Eigen/Core>
|
||||
#include <Eigen/SparseCore>
|
||||
|
||||
// Eigen prior to 3.2.7 doesn't have proper move constructors--but worse, some classes get implicit
|
||||
// move constructors that break things. We could detect this and explicitly copy, but an extra copy
|
||||
// of matrices seems highly undesirable.
|
||||
static_assert(EIGEN_VERSION_AT_LEAST(3,2,7), "Eigen support in pybind11 requires Eigen >= 3.2.7");
|
||||
|
||||
NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
|
||||
|
||||
// Provide a convenience alias for easier pass-by-ref usage with fully dynamic strides:
|
||||
using EigenDStride = Eigen::Stride<Eigen::Dynamic, Eigen::Dynamic>;
|
||||
template <typename MatrixType> using EigenDRef = Eigen::Ref<MatrixType, 0, EigenDStride>;
|
||||
template <typename MatrixType> using EigenDMap = Eigen::Map<MatrixType, 0, EigenDStride>;
|
||||
|
||||
NAMESPACE_BEGIN(detail)
|
||||
|
||||
#if EIGEN_VERSION_AT_LEAST(3,3,0)
|
||||
using EigenIndex = Eigen::Index;
|
||||
#else
|
||||
using EigenIndex = EIGEN_DEFAULT_DENSE_INDEX_TYPE;
|
||||
#endif
|
||||
|
||||
// Matches Eigen::Map, Eigen::Ref, blocks, etc:
|
||||
template <typename T> using is_eigen_dense_map = all_of<is_template_base_of<Eigen::DenseBase, T>, std::is_base_of<Eigen::MapBase<T, Eigen::ReadOnlyAccessors>, T>>;
|
||||
template <typename T> using is_eigen_mutable_map = std::is_base_of<Eigen::MapBase<T, Eigen::WriteAccessors>, T>;
|
||||
template <typename T> using is_eigen_dense_plain = all_of<negation<is_eigen_dense_map<T>>, is_template_base_of<Eigen::PlainObjectBase, T>>;
|
||||
template <typename T> using is_eigen_sparse = is_template_base_of<Eigen::SparseMatrixBase, T>;
|
||||
// Test for objects inheriting from EigenBase<Derived> that aren't captured by the above. This
|
||||
// basically covers anything that can be assigned to a dense matrix but that doesn't have a typical
|
||||
// matrix data layout that can be copied from their .data(). For example, DiagonalMatrix and
|
||||
// SelfAdjointView fall into this category.
|
||||
template <typename T> using is_eigen_other = all_of<
|
||||
is_template_base_of<Eigen::EigenBase, T>,
|
||||
negation<any_of<is_eigen_dense_map<T>, is_eigen_dense_plain<T>, is_eigen_sparse<T>>>
|
||||
>;
|
||||
|
||||
// Captures numpy/eigen conformability status (returned by EigenProps::conformable()):
|
||||
template <bool EigenRowMajor> struct EigenConformable {
|
||||
bool conformable = false;
|
||||
EigenIndex rows = 0, cols = 0;
|
||||
EigenDStride stride{0, 0}; // Only valid if negativestrides is false!
|
||||
bool negativestrides = false; // If true, do not use stride!
|
||||
|
||||
EigenConformable(bool fits = false) : conformable{fits} {}
|
||||
// Matrix type:
|
||||
EigenConformable(EigenIndex r, EigenIndex c,
|
||||
EigenIndex rstride, EigenIndex cstride) :
|
||||
conformable{true}, rows{r}, cols{c} {
|
||||
// TODO: when Eigen bug #747 is fixed, remove the tests for non-negativity. http://eigen.tuxfamily.org/bz/show_bug.cgi?id=747
|
||||
if (rstride < 0 || cstride < 0) {
|
||||
negativestrides = true;
|
||||
} else {
|
||||
stride = {EigenRowMajor ? rstride : cstride /* outer stride */,
|
||||
EigenRowMajor ? cstride : rstride /* inner stride */ };
|
||||
}
|
||||
}
|
||||
// Vector type:
|
||||
EigenConformable(EigenIndex r, EigenIndex c, EigenIndex stride)
|
||||
: EigenConformable(r, c, r == 1 ? c*stride : stride, c == 1 ? r : r*stride) {}
|
||||
|
||||
template <typename props> bool stride_compatible() const {
|
||||
// To have compatible strides, we need (on both dimensions) one of fully dynamic strides,
|
||||
// matching strides, or a dimension size of 1 (in which case the stride value is irrelevant)
|
||||
return
|
||||
!negativestrides &&
|
||||
(props::inner_stride == Eigen::Dynamic || props::inner_stride == stride.inner() ||
|
||||
(EigenRowMajor ? cols : rows) == 1) &&
|
||||
(props::outer_stride == Eigen::Dynamic || props::outer_stride == stride.outer() ||
|
||||
(EigenRowMajor ? rows : cols) == 1);
|
||||
}
|
||||
operator bool() const { return conformable; }
|
||||
};
|
||||
|
||||
template <typename Type> struct eigen_extract_stride { using type = Type; };
|
||||
template <typename PlainObjectType, int MapOptions, typename StrideType>
|
||||
struct eigen_extract_stride<Eigen::Map<PlainObjectType, MapOptions, StrideType>> { using type = StrideType; };
|
||||
template <typename PlainObjectType, int Options, typename StrideType>
|
||||
struct eigen_extract_stride<Eigen::Ref<PlainObjectType, Options, StrideType>> { using type = StrideType; };
|
||||
|
||||
// Helper struct for extracting information from an Eigen type
|
||||
template <typename Type_> struct EigenProps {
|
||||
using Type = Type_;
|
||||
using Scalar = typename Type::Scalar;
|
||||
using StrideType = typename eigen_extract_stride<Type>::type;
|
||||
static constexpr EigenIndex
|
||||
rows = Type::RowsAtCompileTime,
|
||||
cols = Type::ColsAtCompileTime,
|
||||
size = Type::SizeAtCompileTime;
|
||||
static constexpr bool
|
||||
row_major = Type::IsRowMajor,
|
||||
vector = Type::IsVectorAtCompileTime, // At least one dimension has fixed size 1
|
||||
fixed_rows = rows != Eigen::Dynamic,
|
||||
fixed_cols = cols != Eigen::Dynamic,
|
||||
fixed = size != Eigen::Dynamic, // Fully-fixed size
|
||||
dynamic = !fixed_rows && !fixed_cols; // Fully-dynamic size
|
||||
|
||||
template <EigenIndex i, EigenIndex ifzero> using if_zero = std::integral_constant<EigenIndex, i == 0 ? ifzero : i>;
|
||||
static constexpr EigenIndex inner_stride = if_zero<StrideType::InnerStrideAtCompileTime, 1>::value,
|
||||
outer_stride = if_zero<StrideType::OuterStrideAtCompileTime,
|
||||
vector ? size : row_major ? cols : rows>::value;
|
||||
static constexpr bool dynamic_stride = inner_stride == Eigen::Dynamic && outer_stride == Eigen::Dynamic;
|
||||
static constexpr bool requires_row_major = !dynamic_stride && !vector && (row_major ? inner_stride : outer_stride) == 1;
|
||||
static constexpr bool requires_col_major = !dynamic_stride && !vector && (row_major ? outer_stride : inner_stride) == 1;
|
||||
|
||||
// Takes an input array and determines whether we can make it fit into the Eigen type. If
|
||||
// the array is a vector, we attempt to fit it into either an Eigen 1xN or Nx1 vector
|
||||
// (preferring the latter if it will fit in either, i.e. for a fully dynamic matrix type).
|
||||
static EigenConformable<row_major> conformable(const array &a) {
|
||||
const auto dims = a.ndim();
|
||||
if (dims < 1 || dims > 2)
|
||||
return false;
|
||||
|
||||
if (dims == 2) { // Matrix type: require exact match (or dynamic)
|
||||
|
||||
EigenIndex
|
||||
np_rows = a.shape(0),
|
||||
np_cols = a.shape(1),
|
||||
np_rstride = a.strides(0) / static_cast<ssize_t>(sizeof(Scalar)),
|
||||
np_cstride = a.strides(1) / static_cast<ssize_t>(sizeof(Scalar));
|
||||
if ((fixed_rows && np_rows != rows) || (fixed_cols && np_cols != cols))
|
||||
return false;
|
||||
|
||||
return {np_rows, np_cols, np_rstride, np_cstride};
|
||||
}
|
||||
|
||||
// Otherwise we're storing an n-vector. Only one of the strides will be used, but whichever
|
||||
// is used, we want the (single) numpy stride value.
|
||||
const EigenIndex n = a.shape(0),
|
||||
stride = a.strides(0) / static_cast<ssize_t>(sizeof(Scalar));
|
||||
|
||||
if (vector) { // Eigen type is a compile-time vector
|
||||
if (fixed && size != n)
|
||||
return false; // Vector size mismatch
|
||||
return {rows == 1 ? 1 : n, cols == 1 ? 1 : n, stride};
|
||||
}
|
||||
else if (fixed) {
|
||||
// The type has a fixed size, but is not a vector: abort
|
||||
return false;
|
||||
}
|
||||
else if (fixed_cols) {
|
||||
// Since this isn't a vector, cols must be != 1. We allow this only if it exactly
|
||||
// equals the number of elements (rows is Dynamic, and so 1 row is allowed).
|
||||
if (cols != n) return false;
|
||||
return {1, n, stride};
|
||||
}
|
||||
else {
|
||||
// Otherwise it's either fully dynamic, or column dynamic; both become a column vector
|
||||
if (fixed_rows && rows != n) return false;
|
||||
return {n, 1, stride};
|
||||
}
|
||||
}
|
||||
|
||||
static constexpr bool show_writeable = is_eigen_dense_map<Type>::value && is_eigen_mutable_map<Type>::value;
|
||||
static constexpr bool show_order = is_eigen_dense_map<Type>::value;
|
||||
static constexpr bool show_c_contiguous = show_order && requires_row_major;
|
||||
static constexpr bool show_f_contiguous = !show_c_contiguous && show_order && requires_col_major;
|
||||
|
||||
static constexpr auto descriptor =
|
||||
_("numpy.ndarray[") + npy_format_descriptor<Scalar>::name +
|
||||
_("[") + _<fixed_rows>(_<(size_t) rows>(), _("m")) +
|
||||
_(", ") + _<fixed_cols>(_<(size_t) cols>(), _("n")) +
|
||||
_("]") +
|
||||
// For a reference type (e.g. Ref<MatrixXd>) we have other constraints that might need to be
|
||||
// satisfied: writeable=True (for a mutable reference), and, depending on the map's stride
|
||||
// options, possibly f_contiguous or c_contiguous. We include them in the descriptor output
|
||||
// to provide some hint as to why a TypeError is occurring (otherwise it can be confusing to
|
||||
// see that a function accepts a 'numpy.ndarray[float64[3,2]]' and an error message that you
|
||||
// *gave* a numpy.ndarray of the right type and dimensions).
|
||||
_<show_writeable>(", flags.writeable", "") +
|
||||
_<show_c_contiguous>(", flags.c_contiguous", "") +
|
||||
_<show_f_contiguous>(", flags.f_contiguous", "") +
|
||||
_("]");
|
||||
};
|
||||
|
||||
// Casts an Eigen type to numpy array. If given a base, the numpy array references the src data,
|
||||
// otherwise it'll make a copy. writeable lets you turn off the writeable flag for the array.
|
||||
template <typename props> handle eigen_array_cast(typename props::Type const &src, handle base = handle(), bool writeable = true) {
|
||||
constexpr ssize_t elem_size = sizeof(typename props::Scalar);
|
||||
array a;
|
||||
if (props::vector)
|
||||
a = array({ src.size() }, { elem_size * src.innerStride() }, src.data(), base);
|
||||
else
|
||||
a = array({ src.rows(), src.cols() }, { elem_size * src.rowStride(), elem_size * src.colStride() },
|
||||
src.data(), base);
|
||||
|
||||
if (!writeable)
|
||||
array_proxy(a.ptr())->flags &= ~detail::npy_api::NPY_ARRAY_WRITEABLE_;
|
||||
|
||||
return a.release();
|
||||
}
|
||||
|
||||
// Takes an lvalue ref to some Eigen type and a (python) base object, creating a numpy array that
|
||||
// references the Eigen object's data with `base` as the python-registered base class (if omitted,
|
||||
// the base will be set to None, and lifetime management is up to the caller). The numpy array is
|
||||
// non-writeable if the given type is const.
|
||||
template <typename props, typename Type>
|
||||
handle eigen_ref_array(Type &src, handle parent = none()) {
|
||||
// none here is to get past array's should-we-copy detection, which currently always
|
||||
// copies when there is no base. Setting the base to None should be harmless.
|
||||
return eigen_array_cast<props>(src, parent, !std::is_const<Type>::value);
|
||||
}
|
||||
|
||||
// Takes a pointer to some dense, plain Eigen type, builds a capsule around it, then returns a numpy
|
||||
// array that references the encapsulated data with a python-side reference to the capsule to tie
|
||||
// its destruction to that of any dependent python objects. Const-ness is determined by whether or
|
||||
// not the Type of the pointer given is const.
|
||||
template <typename props, typename Type, typename = enable_if_t<is_eigen_dense_plain<Type>::value>>
|
||||
handle eigen_encapsulate(Type *src) {
|
||||
capsule base(src, [](void *o) { delete static_cast<Type *>(o); });
|
||||
return eigen_ref_array<props>(*src, base);
|
||||
}
|
||||
|
||||
// Type caster for regular, dense matrix types (e.g. MatrixXd), but not maps/refs/etc. of dense
|
||||
// types.
|
||||
template<typename Type>
|
||||
struct type_caster<Type, enable_if_t<is_eigen_dense_plain<Type>::value>> {
|
||||
using Scalar = typename Type::Scalar;
|
||||
using props = EigenProps<Type>;
|
||||
|
||||
bool load(handle src, bool convert) {
|
||||
// If we're in no-convert mode, only load if given an array of the correct type
|
||||
if (!convert && !isinstance<array_t<Scalar>>(src))
|
||||
return false;
|
||||
|
||||
// Coerce into an array, but don't do type conversion yet; the copy below handles it.
|
||||
auto buf = array::ensure(src);
|
||||
|
||||
if (!buf)
|
||||
return false;
|
||||
|
||||
auto dims = buf.ndim();
|
||||
if (dims < 1 || dims > 2)
|
||||
return false;
|
||||
|
||||
auto fits = props::conformable(buf);
|
||||
if (!fits)
|
||||
return false;
|
||||
|
||||
// Allocate the new type, then build a numpy reference into it
|
||||
value = Type(fits.rows, fits.cols);
|
||||
auto ref = reinterpret_steal<array>(eigen_ref_array<props>(value));
|
||||
if (dims == 1) ref = ref.squeeze();
|
||||
else if (ref.ndim() == 1) buf = buf.squeeze();
|
||||
|
||||
int result = detail::npy_api::get().PyArray_CopyInto_(ref.ptr(), buf.ptr());
|
||||
|
||||
if (result < 0) { // Copy failed!
|
||||
PyErr_Clear();
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
private:
|
||||
|
||||
// Cast implementation
|
||||
template <typename CType>
|
||||
static handle cast_impl(CType *src, return_value_policy policy, handle parent) {
|
||||
switch (policy) {
|
||||
case return_value_policy::take_ownership:
|
||||
case return_value_policy::automatic:
|
||||
return eigen_encapsulate<props>(src);
|
||||
case return_value_policy::move:
|
||||
return eigen_encapsulate<props>(new CType(std::move(*src)));
|
||||
case return_value_policy::copy:
|
||||
return eigen_array_cast<props>(*src);
|
||||
case return_value_policy::reference:
|
||||
case return_value_policy::automatic_reference:
|
||||
return eigen_ref_array<props>(*src);
|
||||
case return_value_policy::reference_internal:
|
||||
return eigen_ref_array<props>(*src, parent);
|
||||
default:
|
||||
throw cast_error("unhandled return_value_policy: should not happen!");
|
||||
};
|
||||
}
|
||||
|
||||
public:
|
||||
|
||||
// Normal returned non-reference, non-const value:
|
||||
static handle cast(Type &&src, return_value_policy /* policy */, handle parent) {
|
||||
return cast_impl(&src, return_value_policy::move, parent);
|
||||
}
|
||||
// If you return a non-reference const, we mark the numpy array readonly:
|
||||
static handle cast(const Type &&src, return_value_policy /* policy */, handle parent) {
|
||||
return cast_impl(&src, return_value_policy::move, parent);
|
||||
}
|
||||
// lvalue reference return; default (automatic) becomes copy
|
||||
static handle cast(Type &src, return_value_policy policy, handle parent) {
|
||||
if (policy == return_value_policy::automatic || policy == return_value_policy::automatic_reference)
|
||||
policy = return_value_policy::copy;
|
||||
return cast_impl(&src, policy, parent);
|
||||
}
|
||||
// const lvalue reference return; default (automatic) becomes copy
|
||||
static handle cast(const Type &src, return_value_policy policy, handle parent) {
|
||||
if (policy == return_value_policy::automatic || policy == return_value_policy::automatic_reference)
|
||||
policy = return_value_policy::copy;
|
||||
return cast(&src, policy, parent);
|
||||
}
|
||||
// non-const pointer return
|
||||
static handle cast(Type *src, return_value_policy policy, handle parent) {
|
||||
return cast_impl(src, policy, parent);
|
||||
}
|
||||
// const pointer return
|
||||
static handle cast(const Type *src, return_value_policy policy, handle parent) {
|
||||
return cast_impl(src, policy, parent);
|
||||
}
|
||||
|
||||
static constexpr auto name = props::descriptor;
|
||||
|
||||
operator Type*() { return &value; }
|
||||
operator Type&() { return value; }
|
||||
operator Type&&() && { return std::move(value); }
|
||||
template <typename T> using cast_op_type = movable_cast_op_type<T>;
|
||||
|
||||
private:
|
||||
Type value;
|
||||
};
|
||||
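// Usage sketch for the dense-matrix caster above: plain Eigen types convert by copy in both
// directions, so a numpy argument is copied into a fresh Eigen object and the return value is
// copied back into a new numpy array (the function and module names are hypothetical):
//
//     #include <pybind11/eigen.h>
//     m.def("scale", [](const Eigen::MatrixXd &mat, double f) -> Eigen::MatrixXd {
//         return f * mat;   // Python: scale(np.ones((2, 3)), 2.0) -> 2x3 float64 array
//     });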
|
||||
// Base class for casting reference/map/block/etc. objects back to python.
|
||||
template <typename MapType> struct eigen_map_caster {
|
||||
private:
|
||||
using props = EigenProps<MapType>;
|
||||
|
||||
public:
|
||||
|
||||
// Directly referencing a ref/map's data is a bit dangerous (whatever the map/ref points to has
|
||||
// to stay around), but we'll allow it under the assumption that you know what you're doing (and
|
||||
// have an appropriate keep_alive in place). We return a numpy array pointing directly at the
|
||||
// ref's data (The numpy array ends up read-only if the ref was to a const matrix type.) Note
|
||||
// that this means you need to ensure you don't destroy the object in some other way (e.g. with
|
||||
// an appropriate keep_alive, or with a reference to a statically allocated matrix).
|
||||
static handle cast(const MapType &src, return_value_policy policy, handle parent) {
|
||||
switch (policy) {
|
||||
case return_value_policy::copy:
|
||||
return eigen_array_cast<props>(src);
|
||||
case return_value_policy::reference_internal:
|
||||
return eigen_array_cast<props>(src, parent, is_eigen_mutable_map<MapType>::value);
|
||||
case return_value_policy::reference:
|
||||
case return_value_policy::automatic:
|
||||
case return_value_policy::automatic_reference:
|
||||
return eigen_array_cast<props>(src, none(), is_eigen_mutable_map<MapType>::value);
|
||||
default:
|
||||
// move, take_ownership don't make any sense for a ref/map:
|
||||
pybind11_fail("Invalid return_value_policy for Eigen Map/Ref/Block type");
|
||||
}
|
||||
}
|
||||
|
||||
static constexpr auto name = props::descriptor;
|
||||
|
||||
// Explicitly delete these: we don't support python -> C++ conversion on these (i.e. these can be return
|
||||
// types but not bound arguments). We still provide them (explicitly deleted) so that
|
||||
// you end up here if you try anyway.
|
||||
bool load(handle, bool) = delete;
|
||||
operator MapType() = delete;
|
||||
template <typename> using cast_op_type = MapType;
|
||||
};
|
||||
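// Usage sketch: returning an Eigen::Ref/Map through the caster above hands numpy a view of the
// underlying buffer, so the owning C++ object must stay alive; reference_internal ties its
// lifetime to the returned array (the `Holder` type and module `m` are hypothetical):
//
//     py::class_<Holder>(m, "Holder")
//         .def("view", [](Holder &h) -> Eigen::Ref<Eigen::MatrixXd> { return h.matrix; },
//              py::return_value_policy::reference_internal);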
|
||||
// We can return any map-like object (but can only load Refs, specialized next):
|
||||
template <typename Type> struct type_caster<Type, enable_if_t<is_eigen_dense_map<Type>::value>>
|
||||
: eigen_map_caster<Type> {};
|
||||
|
||||
// Loader for Ref<...> arguments. See the documentation for info on how to make this work without
|
||||
// copying (it requires some extra effort in many cases).
|
||||
template <typename PlainObjectType, typename StrideType>
|
||||
struct type_caster<
|
||||
Eigen::Ref<PlainObjectType, 0, StrideType>,
|
||||
enable_if_t<is_eigen_dense_map<Eigen::Ref<PlainObjectType, 0, StrideType>>::value>
|
||||
> : public eigen_map_caster<Eigen::Ref<PlainObjectType, 0, StrideType>> {
|
||||
private:
|
||||
using Type = Eigen::Ref<PlainObjectType, 0, StrideType>;
|
||||
using props = EigenProps<Type>;
|
||||
using Scalar = typename props::Scalar;
|
||||
using MapType = Eigen::Map<PlainObjectType, 0, StrideType>;
|
||||
using Array = array_t<Scalar, array::forcecast |
|
||||
((props::row_major ? props::inner_stride : props::outer_stride) == 1 ? array::c_style :
|
||||
(props::row_major ? props::outer_stride : props::inner_stride) == 1 ? array::f_style : 0)>;
|
||||
static constexpr bool need_writeable = is_eigen_mutable_map<Type>::value;
|
||||
// Delay construction (these have no default constructor)
|
||||
std::unique_ptr<MapType> map;
|
||||
std::unique_ptr<Type> ref;
|
||||
// Our array. When possible, this is just a numpy array pointing to the source data, but
|
||||
// sometimes we can't avoid copying (e.g. input is not a numpy array at all, has an incompatible
|
||||
// layout, or is an array of a type that needs to be converted). Using a numpy temporary
|
||||
// (rather than an Eigen temporary) saves an extra copy when we need both type conversion and
|
||||
// storage order conversion. (Note that we refuse to use this temporary copy when loading an
|
||||
// argument for a Ref<M> with M non-const, i.e. a read-write reference).
|
||||
Array copy_or_ref;
|
||||
public:
|
||||
bool load(handle src, bool convert) {
|
||||
// First check whether what we have is already an array of the right type. If not, we can't
|
||||
// avoid a copy (because the copy is also going to do type conversion).
|
||||
bool need_copy = !isinstance<Array>(src);
|
||||
|
||||
EigenConformable<props::row_major> fits;
|
||||
if (!need_copy) {
|
||||
// We don't need a converting copy, but we also need to check whether the strides are
|
||||
// compatible with the Ref's stride requirements
|
||||
Array aref = reinterpret_borrow<Array>(src);
|
||||
|
||||
if (aref && (!need_writeable || aref.writeable())) {
|
||||
fits = props::conformable(aref);
|
||||
if (!fits) return false; // Incompatible dimensions
|
||||
if (!fits.template stride_compatible<props>())
|
||||
need_copy = true;
|
||||
else
|
||||
copy_or_ref = std::move(aref);
|
||||
}
|
||||
else {
|
||||
need_copy = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (need_copy) {
|
||||
// We need to copy: If we need a mutable reference, or we're not supposed to convert
|
||||
// (either because we're in the no-convert overload pass, or because we're explicitly
|
||||
// instructed not to copy (via `py::arg().noconvert()`)), we have to fail loading.
|
||||
if (!convert || need_writeable) return false;
|
||||
|
||||
Array copy = Array::ensure(src);
|
||||
if (!copy) return false;
|
||||
fits = props::conformable(copy);
|
||||
if (!fits || !fits.template stride_compatible<props>())
|
||||
return false;
|
||||
copy_or_ref = std::move(copy);
|
||||
loader_life_support::add_patient(copy_or_ref);
|
||||
}
|
||||
|
||||
ref.reset();
|
||||
map.reset(new MapType(data(copy_or_ref), fits.rows, fits.cols, make_stride(fits.stride.outer(), fits.stride.inner())));
|
||||
ref.reset(new Type(*map));
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
operator Type*() { return ref.get(); }
|
||||
operator Type&() { return *ref; }
|
||||
template <typename _T> using cast_op_type = pybind11::detail::cast_op_type<_T>;
|
||||
|
||||
private:
|
||||
template <typename T = Type, enable_if_t<is_eigen_mutable_map<T>::value, int> = 0>
|
||||
Scalar *data(Array &a) { return a.mutable_data(); }
|
||||
|
||||
template <typename T = Type, enable_if_t<!is_eigen_mutable_map<T>::value, int> = 0>
|
||||
const Scalar *data(Array &a) { return a.data(); }
|
||||
|
||||
// Attempt to figure out a constructor of `Stride` that will work.
|
||||
// If both strides are fixed, use a default constructor:
|
||||
template <typename S> using stride_ctor_default = bool_constant<
|
||||
S::InnerStrideAtCompileTime != Eigen::Dynamic && S::OuterStrideAtCompileTime != Eigen::Dynamic &&
|
||||
std::is_default_constructible<S>::value>;
|
||||
// Otherwise, if there is a two-index constructor, assume it is (outer,inner) like
|
||||
// Eigen::Stride, and use it:
|
||||
template <typename S> using stride_ctor_dual = bool_constant<
|
||||
!stride_ctor_default<S>::value && std::is_constructible<S, EigenIndex, EigenIndex>::value>;
|
||||
// Otherwise, if there is a one-index constructor, and just one of the strides is dynamic, use
|
||||
// it (passing whichever stride is dynamic).
|
||||
template <typename S> using stride_ctor_outer = bool_constant<
|
||||
!any_of<stride_ctor_default<S>, stride_ctor_dual<S>>::value &&
|
||||
S::OuterStrideAtCompileTime == Eigen::Dynamic && S::InnerStrideAtCompileTime != Eigen::Dynamic &&
|
||||
std::is_constructible<S, EigenIndex>::value>;
|
||||
template <typename S> using stride_ctor_inner = bool_constant<
|
||||
!any_of<stride_ctor_default<S>, stride_ctor_dual<S>>::value &&
|
||||
S::InnerStrideAtCompileTime == Eigen::Dynamic && S::OuterStrideAtCompileTime != Eigen::Dynamic &&
|
||||
std::is_constructible<S, EigenIndex>::value>;
|
||||
|
||||
template <typename S = StrideType, enable_if_t<stride_ctor_default<S>::value, int> = 0>
|
||||
static S make_stride(EigenIndex, EigenIndex) { return S(); }
|
||||
template <typename S = StrideType, enable_if_t<stride_ctor_dual<S>::value, int> = 0>
|
||||
static S make_stride(EigenIndex outer, EigenIndex inner) { return S(outer, inner); }
|
||||
template <typename S = StrideType, enable_if_t<stride_ctor_outer<S>::value, int> = 0>
|
||||
static S make_stride(EigenIndex outer, EigenIndex) { return S(outer); }
|
||||
template <typename S = StrideType, enable_if_t<stride_ctor_inner<S>::value, int> = 0>
|
||||
static S make_stride(EigenIndex, EigenIndex inner) { return S(inner); }
|
||||
|
||||
};
|
||||
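// Usage sketch for the Ref<...> loader above (hypothetical functions in a hypothetical module
// `m`): a const Ref accepts any conforming array and silently copies/converts when needed,
// while a mutable Ref insists on a writeable array of matching dtype and layout;
// py::arg().noconvert() turns would-be copies into a TypeError instead:
//
//     m.def("sum_of", [](Eigen::Ref<const Eigen::VectorXd> v) { return v.sum(); });
//     m.def("scale_inplace", [](Eigen::Ref<Eigen::VectorXd> v, double f) { v *= f; },
//           py::arg().noconvert(), py::arg("f"));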
|
||||
// type_caster for special matrix types (e.g. DiagonalMatrix), which are EigenBase, but not
|
||||
// EigenDense (i.e. they don't have a data(), at least not with the usual matrix layout).
|
||||
// load() is not supported, but we can cast them into the python domain by first copying to a
|
||||
// regular Eigen::Matrix, then casting that.
|
||||
template <typename Type>
|
||||
struct type_caster<Type, enable_if_t<is_eigen_other<Type>::value>> {
|
||||
protected:
|
||||
using Matrix = Eigen::Matrix<typename Type::Scalar, Type::RowsAtCompileTime, Type::ColsAtCompileTime>;
|
||||
using props = EigenProps<Matrix>;
|
||||
public:
|
||||
static handle cast(const Type &src, return_value_policy /* policy */, handle /* parent */) {
|
||||
handle h = eigen_encapsulate<props>(new Matrix(src));
|
||||
return h;
|
||||
}
|
||||
static handle cast(const Type *src, return_value_policy policy, handle parent) { return cast(*src, policy, parent); }
|
||||
|
||||
static constexpr auto name = props::descriptor;
|
||||
|
||||
// Explicitly delete these: we don't support python -> C++ conversion on these (i.e. these can be return
|
||||
// types but not bound arguments). We still provide them (explicitly deleted) so that
|
||||
// you end up here if you try anyway.
|
||||
bool load(handle, bool) = delete;
|
||||
operator Type() = delete;
|
||||
template <typename> using cast_op_type = Type;
|
||||
};
|
||||
|
||||
template<typename Type>
|
||||
struct type_caster<Type, enable_if_t<is_eigen_sparse<Type>::value>> {
|
||||
typedef typename Type::Scalar Scalar;
|
||||
typedef remove_reference_t<decltype(*std::declval<Type>().outerIndexPtr())> StorageIndex;
|
||||
typedef typename Type::Index Index;
|
||||
static constexpr bool rowMajor = Type::IsRowMajor;
|
||||
|
||||
bool load(handle src, bool) {
|
||||
if (!src)
|
||||
return false;
|
||||
|
||||
auto obj = reinterpret_borrow<object>(src);
|
||||
object sparse_module = module::import("scipy.sparse");
|
||||
object matrix_type = sparse_module.attr(
|
||||
rowMajor ? "csr_matrix" : "csc_matrix");
|
||||
|
||||
if (!obj.get_type().is(matrix_type)) {
|
||||
try {
|
||||
obj = matrix_type(obj);
|
||||
} catch (const error_already_set &) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
auto values = array_t<Scalar>((object) obj.attr("data"));
|
||||
auto innerIndices = array_t<StorageIndex>((object) obj.attr("indices"));
|
||||
auto outerIndices = array_t<StorageIndex>((object) obj.attr("indptr"));
|
||||
auto shape = pybind11::tuple((pybind11::object) obj.attr("shape"));
|
||||
auto nnz = obj.attr("nnz").cast<Index>();
|
||||
|
||||
if (!values || !innerIndices || !outerIndices)
|
||||
return false;
|
||||
|
||||
value = Eigen::MappedSparseMatrix<Scalar, Type::Flags, StorageIndex>(
|
||||
shape[0].cast<Index>(), shape[1].cast<Index>(), nnz,
|
||||
outerIndices.mutable_data(), innerIndices.mutable_data(), values.mutable_data());
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
static handle cast(const Type &src, return_value_policy /* policy */, handle /* parent */) {
|
||||
const_cast<Type&>(src).makeCompressed();
|
||||
|
||||
object matrix_type = module::import("scipy.sparse").attr(
|
||||
rowMajor ? "csr_matrix" : "csc_matrix");
|
||||
|
||||
array data(src.nonZeros(), src.valuePtr());
|
||||
array outerIndices((rowMajor ? src.rows() : src.cols()) + 1, src.outerIndexPtr());
|
||||
array innerIndices(src.nonZeros(), src.innerIndexPtr());
|
||||
|
||||
return matrix_type(
|
||||
std::make_tuple(data, innerIndices, outerIndices),
|
||||
std::make_pair(src.rows(), src.cols())
|
||||
).release();
|
||||
}
|
||||
|
||||
PYBIND11_TYPE_CASTER(Type, _<(Type::IsRowMajor) != 0>("scipy.sparse.csr_matrix[", "scipy.sparse.csc_matrix[")
|
||||
+ npy_format_descriptor<Scalar>::name + _("]"));
|
||||
};
|
||||
|
||||
NAMESPACE_END(detail)
|
||||
NAMESPACE_END(PYBIND11_NAMESPACE)
|
||||
|
||||
#if defined(__GNUG__) || defined(__clang__)
|
||||
# pragma GCC diagnostic pop
|
||||
#elif defined(_MSC_VER)
|
||||
# pragma warning(pop)
|
||||
#endif
|
|
@ -1,200 +0,0 @@
|
|||
/*
|
||||
pybind11/embed.h: Support for embedding the interpreter
|
||||
|
||||
Copyright (c) 2017 Wenzel Jakob <wenzel.jakob@epfl.ch>
|
||||
|
||||
All rights reserved. Use of this source code is governed by a
|
||||
BSD-style license that can be found in the LICENSE file.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "pybind11.h"
|
||||
#include "eval.h"
|
||||
|
||||
#if defined(PYPY_VERSION)
|
||||
# error Embedding the interpreter is not supported with PyPy
|
||||
#endif
|
||||
|
||||
#if PY_MAJOR_VERSION >= 3
|
||||
# define PYBIND11_EMBEDDED_MODULE_IMPL(name) \
|
||||
extern "C" PyObject *pybind11_init_impl_##name() { \
|
||||
return pybind11_init_wrapper_##name(); \
|
||||
}
|
||||
#else
|
||||
# define PYBIND11_EMBEDDED_MODULE_IMPL(name) \
|
||||
extern "C" void pybind11_init_impl_##name() { \
|
||||
pybind11_init_wrapper_##name(); \
|
||||
}
|
||||
#endif
|
||||
|
||||
/** \rst
|
||||
Add a new module to the table of builtins for the interpreter. Must be
|
||||
defined in global scope. The first macro parameter is the name of the
|
||||
module (without quotes). The second parameter is the variable which will
|
||||
be used as the interface to add functions and classes to the module.
|
||||
|
||||
.. code-block:: cpp
|
||||
|
||||
PYBIND11_EMBEDDED_MODULE(example, m) {
|
||||
// ... initialize functions and classes here
|
||||
m.def("foo", []() {
|
||||
return "Hello, World!";
|
||||
});
|
||||
}
|
||||
\endrst */
|
||||
#define PYBIND11_EMBEDDED_MODULE(name, variable) \
|
||||
static void PYBIND11_CONCAT(pybind11_init_, name)(pybind11::module &); \
|
||||
static PyObject PYBIND11_CONCAT(*pybind11_init_wrapper_, name)() { \
|
||||
auto m = pybind11::module(PYBIND11_TOSTRING(name)); \
|
||||
try { \
|
||||
PYBIND11_CONCAT(pybind11_init_, name)(m); \
|
||||
return m.ptr(); \
|
||||
} catch (pybind11::error_already_set &e) { \
|
||||
PyErr_SetString(PyExc_ImportError, e.what()); \
|
||||
return nullptr; \
|
||||
} catch (const std::exception &e) { \
|
||||
PyErr_SetString(PyExc_ImportError, e.what()); \
|
||||
return nullptr; \
|
||||
} \
|
||||
} \
|
||||
PYBIND11_EMBEDDED_MODULE_IMPL(name) \
|
||||
pybind11::detail::embedded_module name(PYBIND11_TOSTRING(name), \
|
||||
PYBIND11_CONCAT(pybind11_init_impl_, name)); \
|
||||
void PYBIND11_CONCAT(pybind11_init_, name)(pybind11::module &variable)
|
||||
|
||||
|
||||
NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
|
||||
NAMESPACE_BEGIN(detail)
|
||||
|
||||
/// Python 2.7/3.x compatible version of `PyImport_AppendInittab` and error checks.
|
||||
struct embedded_module {
|
||||
#if PY_MAJOR_VERSION >= 3
|
||||
using init_t = PyObject *(*)();
|
||||
#else
|
||||
using init_t = void (*)();
|
||||
#endif
|
||||
embedded_module(const char *name, init_t init) {
|
||||
if (Py_IsInitialized())
|
||||
pybind11_fail("Can't add new modules after the interpreter has been initialized");
|
||||
|
||||
auto result = PyImport_AppendInittab(name, init);
|
||||
if (result == -1)
|
||||
pybind11_fail("Insufficient memory to add a new module");
|
||||
}
|
||||
};
|
||||
|
||||
NAMESPACE_END(detail)
|
||||
|
||||
/** \rst
|
||||
Initialize the Python interpreter. No other pybind11 or CPython API functions can be
|
||||
called before this is done, with the exception of `PYBIND11_EMBEDDED_MODULE`. The
|
||||
optional parameter can be used to skip the registration of signal handlers (see the
|
||||
`Python documentation`_ for details). Calling this function again after the interpreter
|
||||
has already been initialized is a fatal error.
|
||||
|
||||
If initializing the Python interpreter fails, then the program is terminated. (This
|
||||
is controlled by the CPython runtime and is an exception to pybind11's normal behavior
|
||||
of throwing exceptions on errors.)
|
||||
|
||||
.. _Python documentation: https://docs.python.org/3/c-api/init.html#c.Py_InitializeEx
|
||||
\endrst */
|
||||
inline void initialize_interpreter(bool init_signal_handlers = true) {
|
||||
if (Py_IsInitialized())
|
||||
pybind11_fail("The interpreter is already running");
|
||||
|
||||
Py_InitializeEx(init_signal_handlers ? 1 : 0);
|
||||
|
||||
// Make .py files in the working directory available by default
|
||||
module::import("sys").attr("path").cast<list>().append(".");
|
||||
}
|
||||
|
||||
/** \rst
|
||||
Shut down the Python interpreter. No pybind11 or CPython API functions can be called
|
||||
after this. In addition, pybind11 objects must not outlive the interpreter:
|
||||
|
||||
.. code-block:: cpp
|
||||
|
||||
{ // BAD
|
||||
py::initialize_interpreter();
|
||||
auto hello = py::str("Hello, World!");
|
||||
py::finalize_interpreter();
|
||||
} // <-- BOOM, hello's destructor is called after interpreter shutdown
|
||||
|
||||
{ // GOOD
|
||||
py::initialize_interpreter();
|
||||
{ // scoped
|
||||
auto hello = py::str("Hello, World!");
|
||||
} // <-- OK, hello is cleaned up properly
|
||||
py::finalize_interpreter();
|
||||
}
|
||||
|
||||
{ // BETTER
|
||||
py::scoped_interpreter guard{};
|
||||
auto hello = py::str("Hello, World!");
|
||||
}
|
||||
|
||||
.. warning::
|
||||
|
||||
The interpreter can be restarted by calling `initialize_interpreter` again.
|
||||
Modules created using pybind11 can be safely re-initialized. However, Python
|
||||
itself cannot completely unload binary extension modules and there are several
|
||||
caveats with regard to interpreter restarting. All the details can be found
|
||||
in the CPython documentation. In short, not all interpreter memory may be
|
||||
freed, either due to reference cycles or user-created global data.
|
||||
|
||||
\endrst */
|
||||
inline void finalize_interpreter() {
|
||||
handle builtins(PyEval_GetBuiltins());
|
||||
const char *id = PYBIND11_INTERNALS_ID;
|
||||
|
||||
// Get the internals pointer (without creating it if it doesn't exist). It's possible for the
|
||||
// internals to be created during Py_Finalize() (e.g. if a py::capsule calls `get_internals()`
|
||||
// during destruction), so we get the pointer-pointer here and check it after Py_Finalize().
|
||||
detail::internals **internals_ptr_ptr = detail::get_internals_pp();
|
||||
// It could also be stashed in builtins, so look there too:
|
||||
if (builtins.contains(id) && isinstance<capsule>(builtins[id]))
|
||||
internals_ptr_ptr = capsule(builtins[id]);
|
||||
|
||||
Py_Finalize();
|
||||
|
||||
if (internals_ptr_ptr) {
|
||||
delete *internals_ptr_ptr;
|
||||
*internals_ptr_ptr = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
/** \rst
|
||||
Scope guard version of `initialize_interpreter` and `finalize_interpreter`.
|
||||
This is a move-only guard and only a single instance can exist.
|
||||
|
||||
.. code-block:: cpp
|
||||
|
||||
#include <pybind11/embed.h>
|
||||
|
||||
int main() {
|
||||
py::scoped_interpreter guard{};
|
||||
py::print("Hello, World!");
|
||||
} // <-- interpreter shutdown
|
||||
\endrst */
|
||||
class scoped_interpreter {
|
||||
public:
|
||||
scoped_interpreter(bool init_signal_handlers = true) {
|
||||
initialize_interpreter(init_signal_handlers);
|
||||
}
|
||||
|
||||
scoped_interpreter(const scoped_interpreter &) = delete;
|
||||
scoped_interpreter(scoped_interpreter &&other) noexcept { other.is_valid = false; }
|
||||
scoped_interpreter &operator=(const scoped_interpreter &) = delete;
|
||||
scoped_interpreter &operator=(scoped_interpreter &&) = delete;
|
||||
|
||||
~scoped_interpreter() {
|
||||
if (is_valid)
|
||||
finalize_interpreter();
|
||||
}
|
||||
|
||||
private:
|
||||
bool is_valid = true;
|
||||
};
|
||||
|
||||
NAMESPACE_END(PYBIND11_NAMESPACE)
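// ---------------------------------------------------------------------------
// Illustrative usage sketch (editor's addition, not part of the original
// header): how `PYBIND11_EMBEDDED_MODULE` and `scoped_interpreter` above are
// typically combined. The module name "fast_calc" is hypothetical.
//
//     #include <pybind11/embed.h>
//     namespace py = pybind11;
//
//     PYBIND11_EMBEDDED_MODULE(fast_calc, m) {
//         m.def("add", [](int i, int j) { return i + j; });
//     }
//
//     int main() {
//         py::scoped_interpreter guard{};   // start and keep the interpreter alive
//         auto fast_calc = py::module::import("fast_calc");
//         int result = fast_calc.attr("add")(1, 2).cast<int>();
//         return result == 3 ? 0 : 1;
//     }
// ---------------------------------------------------------------------------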
|
|
@ -1,117 +0,0 @@
|
|||
/*
|
||||
pybind11/exec.h: Support for evaluating Python expressions and statements
|
||||
from strings and files
|
||||
|
||||
Copyright (c) 2016 Klemens Morgenstern <klemens.morgenstern@ed-chemnitz.de> and
|
||||
Wenzel Jakob <wenzel.jakob@epfl.ch>
|
||||
|
||||
All rights reserved. Use of this source code is governed by a
|
||||
BSD-style license that can be found in the LICENSE file.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "pybind11.h"
|
||||
|
||||
NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
|
||||
|
||||
enum eval_mode {
|
||||
/// Evaluate a string containing an isolated expression
|
||||
eval_expr,
|
||||
|
||||
/// Evaluate a string containing a single statement. Returns \c none
|
||||
eval_single_statement,
|
||||
|
||||
/// Evaluate a string containing a sequence of statements. Returns \c none
|
||||
eval_statements
|
||||
};
|
||||
|
||||
template <eval_mode mode = eval_expr>
|
||||
object eval(str expr, object global = globals(), object local = object()) {
|
||||
if (!local)
|
||||
local = global;
|
||||
|
||||
/* PyRun_String does not accept a PyObject / encoding specifier,
|
||||
this seems to be the only alternative */
|
||||
std::string buffer = "# -*- coding: utf-8 -*-\n" + (std::string) expr;
|
||||
|
||||
int start;
|
||||
switch (mode) {
|
||||
case eval_expr: start = Py_eval_input; break;
|
||||
case eval_single_statement: start = Py_single_input; break;
|
||||
case eval_statements: start = Py_file_input; break;
|
||||
default: pybind11_fail("invalid evaluation mode");
|
||||
}
|
||||
|
||||
PyObject *result = PyRun_String(buffer.c_str(), start, global.ptr(), local.ptr());
|
||||
if (!result)
|
||||
throw error_already_set();
|
||||
return reinterpret_steal<object>(result);
|
||||
}
|
||||
|
||||
template <eval_mode mode = eval_expr, size_t N>
|
||||
object eval(const char (&s)[N], object global = globals(), object local = object()) {
|
||||
/* Support raw string literals by removing common leading whitespace */
|
||||
auto expr = (s[0] == '\n') ? str(module::import("textwrap").attr("dedent")(s))
|
||||
: str(s);
|
||||
return eval<mode>(expr, global, local);
|
||||
}
|
||||
|
||||
inline void exec(str expr, object global = globals(), object local = object()) {
|
||||
eval<eval_statements>(expr, global, local);
|
||||
}
|
||||
|
||||
template <size_t N>
|
||||
void exec(const char (&s)[N], object global = globals(), object local = object()) {
|
||||
eval<eval_statements>(s, global, local);
|
||||
}
|
||||
|
||||
template <eval_mode mode = eval_statements>
|
||||
object eval_file(str fname, object global = globals(), object local = object()) {
|
||||
if (!local)
|
||||
local = global;
|
||||
|
||||
int start;
|
||||
switch (mode) {
|
||||
case eval_expr: start = Py_eval_input; break;
|
||||
case eval_single_statement: start = Py_single_input; break;
|
||||
case eval_statements: start = Py_file_input; break;
|
||||
default: pybind11_fail("invalid evaluation mode");
|
||||
}
|
||||
|
||||
int closeFile = 1;
|
||||
std::string fname_str = (std::string) fname;
|
||||
#if PY_VERSION_HEX >= 0x03040000
|
||||
FILE *f = _Py_fopen_obj(fname.ptr(), "r");
|
||||
#elif PY_VERSION_HEX >= 0x03000000
|
||||
FILE *f = _Py_fopen(fname.ptr(), "r");
|
||||
#else
|
||||
/* No unicode support in open() :( */
|
||||
auto fobj = reinterpret_steal<object>(PyFile_FromString(
|
||||
const_cast<char *>(fname_str.c_str()),
|
||||
const_cast<char*>("r")));
|
||||
FILE *f = nullptr;
|
||||
if (fobj)
|
||||
f = PyFile_AsFile(fobj.ptr());
|
||||
closeFile = 0;
|
||||
#endif
|
||||
if (!f) {
|
||||
PyErr_Clear();
|
||||
pybind11_fail("File \"" + fname_str + "\" could not be opened!");
|
||||
}
|
||||
|
||||
#if PY_VERSION_HEX < 0x03000000 && defined(PYPY_VERSION)
|
||||
PyObject *result = PyRun_File(f, fname_str.c_str(), start, global.ptr(),
|
||||
local.ptr());
|
||||
(void) closeFile;
|
||||
#else
|
||||
PyObject *result = PyRun_FileEx(f, fname_str.c_str(), start, global.ptr(),
|
||||
local.ptr(), closeFile);
|
||||
#endif
|
||||
|
||||
if (!result)
|
||||
throw error_already_set();
|
||||
return reinterpret_steal<object>(result);
|
||||
}
|
||||
|
||||
NAMESPACE_END(PYBIND11_NAMESPACE)
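// ---------------------------------------------------------------------------
// Illustrative usage sketch (editor's addition, not part of the original
// header): a minimal demo of the three eval modes defined above, assuming an
// interpreter is already running (e.g. via py::scoped_interpreter from
// <pybind11/embed.h>). The function name run_eval_modes_demo is hypothetical
// and not part of the pybind11 API.
inline void run_eval_modes_demo() {
    namespace py = pybind11;
    auto scope = py::globals();

    // eval_expr (default mode): evaluate a single expression and use its value
    int three = py::eval("1 + 2", scope).cast<int>();
    (void) three;

    // eval_single_statement: run exactly one statement; the result is None
    py::eval<py::eval_single_statement>("x = 40 + 2", scope);

    // eval_statements (what exec() forwards to): run a block of statements
    py::exec("y = x + 1\nprint(y)", scope);
}
// ---------------------------------------------------------------------------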
|
|
@ -1,83 +0,0 @@
|
|||
/*
|
||||
pybind11/functional.h: std::function<> support
|
||||
|
||||
Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>
|
||||
|
||||
All rights reserved. Use of this source code is governed by a
|
||||
BSD-style license that can be found in the LICENSE file.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "pybind11.h"
|
||||
#include <functional>
|
||||
|
||||
NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
|
||||
NAMESPACE_BEGIN(detail)
|
||||
|
||||
template <typename Return, typename... Args>
|
||||
struct type_caster<std::function<Return(Args...)>> {
|
||||
using type = std::function<Return(Args...)>;
|
||||
using retval_type = conditional_t<std::is_same<Return, void>::value, void_type, Return>;
|
||||
using function_type = Return (*) (Args...);
|
||||
|
||||
public:
|
||||
bool load(handle src, bool convert) {
|
||||
if (src.is_none()) {
|
||||
// Defer accepting None to other overloads (if we aren't in convert mode):
|
||||
if (!convert) return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
if (!isinstance<function>(src))
|
||||
return false;
|
||||
|
||||
auto func = reinterpret_borrow<function>(src);
|
||||
|
||||
/*
|
||||
When passing a C++ function as an argument to another C++
|
||||
function via Python, every function call would normally involve
|
||||
a full C++ -> Python -> C++ roundtrip, which can be prohibitive.
|
||||
Here, we try to at least detect the case where the function is
|
||||
stateless (i.e. function pointer or lambda function without
|
||||
captured variables), in which case the roundtrip can be avoided.
|
||||
*/
|
||||
if (auto cfunc = func.cpp_function()) {
|
||||
auto c = reinterpret_borrow<capsule>(PyCFunction_GET_SELF(cfunc.ptr()));
|
||||
auto rec = (function_record *) c;
|
||||
|
||||
if (rec && rec->is_stateless &&
|
||||
same_type(typeid(function_type), *reinterpret_cast<const std::type_info *>(rec->data[1]))) {
|
||||
struct capture { function_type f; };
|
||||
value = ((capture *) &rec->data)->f;
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
value = [func](Args... args) -> Return {
|
||||
gil_scoped_acquire acq;
|
||||
object retval(func(std::forward<Args>(args)...));
|
||||
/* Visual studio 2015 parser issue: need parentheses around this expression */
|
||||
return (retval.template cast<Return>());
|
||||
};
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename Func>
|
||||
static handle cast(Func &&f_, return_value_policy policy, handle /* parent */) {
|
||||
if (!f_)
|
||||
return none().inc_ref();
|
||||
|
||||
auto result = f_.template target<function_type>();
|
||||
if (result)
|
||||
return cpp_function(*result, policy).release();
|
||||
else
|
||||
return cpp_function(std::forward<Func>(f_), policy).release();
|
||||
}
|
||||
|
||||
PYBIND11_TYPE_CASTER(type, _("Callable[[") + concat(make_caster<Args>::name...) + _("], ")
|
||||
+ make_caster<retval_type>::name + _("]"));
|
||||
};
|
||||
|
||||
NAMESPACE_END(detail)
|
||||
NAMESPACE_END(PYBIND11_NAMESPACE)
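// ---------------------------------------------------------------------------
// Illustrative usage sketch (editor's addition, not part of the original
// header): the caster above lets Python callables be passed wherever C++
// expects a std::function. The module and function names below are
// hypothetical.
//
//     #include <pybind11/pybind11.h>
//     #include <pybind11/functional.h>
//     #include <functional>
//
//     int apply_twice(const std::function<int(int)> &f, int x) {
//         return f(f(x));   // f may wrap a Python callable supplied by the caller
//     }
//
//     PYBIND11_MODULE(functional_demo, m) {
//         m.def("apply_twice", &apply_twice);
//     }
//
// From Python: functional_demo.apply_twice(lambda v: v + 1, 0) == 2
// ---------------------------------------------------------------------------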
|
|
@ -1,200 +0,0 @@
|
|||
/*
|
||||
pybind11/iostream.h -- Tools to assist with redirecting cout and cerr to Python
|
||||
|
||||
Copyright (c) 2017 Henry F. Schreiner
|
||||
|
||||
All rights reserved. Use of this source code is governed by a
|
||||
BSD-style license that can be found in the LICENSE file.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "pybind11.h"
|
||||
|
||||
#include <streambuf>
|
||||
#include <ostream>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <iostream>
|
||||
|
||||
NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
|
||||
NAMESPACE_BEGIN(detail)
|
||||
|
||||
// Buffer that writes to Python instead of C++
|
||||
class pythonbuf : public std::streambuf {
|
||||
private:
|
||||
using traits_type = std::streambuf::traits_type;
|
||||
|
||||
char d_buffer[1024];
|
||||
object pywrite;
|
||||
object pyflush;
|
||||
|
||||
int overflow(int c) {
|
||||
if (!traits_type::eq_int_type(c, traits_type::eof())) {
|
||||
*pptr() = traits_type::to_char_type(c);
|
||||
pbump(1);
|
||||
}
|
||||
return sync() == 0 ? traits_type::not_eof(c) : traits_type::eof();
|
||||
}
|
||||
|
||||
int sync() {
|
||||
if (pbase() != pptr()) {
|
||||
// This subtraction cannot be negative, so dropping the sign
|
||||
str line(pbase(), static_cast<size_t>(pptr() - pbase()));
|
||||
|
||||
pywrite(line);
|
||||
pyflush();
|
||||
|
||||
setp(pbase(), epptr());
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
public:
|
||||
pythonbuf(object pyostream)
|
||||
: pywrite(pyostream.attr("write")),
|
||||
pyflush(pyostream.attr("flush")) {
|
||||
setp(d_buffer, d_buffer + sizeof(d_buffer) - 1);
|
||||
}
|
||||
|
||||
/// Sync before destroy
|
||||
~pythonbuf() {
|
||||
sync();
|
||||
}
|
||||
};
|
||||
|
||||
NAMESPACE_END(detail)
|
||||
|
||||
|
||||
/** \rst
|
||||
This is a move-only guard that redirects output.
|
||||
|
||||
.. code-block:: cpp
|
||||
|
||||
#include <pybind11/iostream.h>
|
||||
|
||||
...
|
||||
|
||||
{
|
||||
py::scoped_ostream_redirect output;
|
||||
std::cout << "Hello, World!"; // Python stdout
|
||||
} // <-- return std::cout to normal
|
||||
|
||||
You can explicitly pass the C++ stream and the Python object,
|
||||
for example to guard stderr instead.
|
||||
|
||||
.. code-block:: cpp
|
||||
|
||||
{
|
||||
py::scoped_ostream_redirect output{std::cerr, py::module::import("sys").attr("stderr")};
|
||||
std::cerr << "Hello, World!";
|
||||
}
|
||||
\endrst */
|
||||
class scoped_ostream_redirect {
|
||||
protected:
|
||||
std::streambuf *old;
|
||||
std::ostream &costream;
|
||||
detail::pythonbuf buffer;
|
||||
|
||||
public:
|
||||
scoped_ostream_redirect(
|
||||
std::ostream &costream = std::cout,
|
||||
object pyostream = module::import("sys").attr("stdout"))
|
||||
: costream(costream), buffer(pyostream) {
|
||||
old = costream.rdbuf(&buffer);
|
||||
}
|
||||
|
||||
~scoped_ostream_redirect() {
|
||||
costream.rdbuf(old);
|
||||
}
|
||||
|
||||
scoped_ostream_redirect(const scoped_ostream_redirect &) = delete;
|
||||
scoped_ostream_redirect(scoped_ostream_redirect &&other) = default;
|
||||
scoped_ostream_redirect &operator=(const scoped_ostream_redirect &) = delete;
|
||||
scoped_ostream_redirect &operator=(scoped_ostream_redirect &&) = delete;
|
||||
};
|
||||
|
||||
|
||||
/** \rst
|
||||
Like `scoped_ostream_redirect`, but redirects cerr by default. This class
|
||||
is provided primarily to make ``py::call_guard`` easier to use.
|
||||
|
||||
.. code-block:: cpp
|
||||
|
||||
m.def("noisy_func", &noisy_func,
|
||||
py::call_guard<scoped_ostream_redirect,
|
||||
scoped_estream_redirect>());
|
||||
|
||||
\endrst */
|
||||
class scoped_estream_redirect : public scoped_ostream_redirect {
|
||||
public:
|
||||
scoped_estream_redirect(
|
||||
std::ostream &costream = std::cerr,
|
||||
object pyostream = module::import("sys").attr("stderr"))
|
||||
: scoped_ostream_redirect(costream,pyostream) {}
|
||||
};
|
||||
|
||||
|
||||
NAMESPACE_BEGIN(detail)
|
||||
|
||||
// Class to redirect output as a context manager. C++ backend.
|
||||
class OstreamRedirect {
|
||||
bool do_stdout_;
|
||||
bool do_stderr_;
|
||||
std::unique_ptr<scoped_ostream_redirect> redirect_stdout;
|
||||
std::unique_ptr<scoped_estream_redirect> redirect_stderr;
|
||||
|
||||
public:
|
||||
OstreamRedirect(bool do_stdout = true, bool do_stderr = true)
|
||||
: do_stdout_(do_stdout), do_stderr_(do_stderr) {}
|
||||
|
||||
void enter() {
|
||||
if (do_stdout_)
|
||||
redirect_stdout.reset(new scoped_ostream_redirect());
|
||||
if (do_stderr_)
|
||||
redirect_stderr.reset(new scoped_estream_redirect());
|
||||
}
|
||||
|
||||
void exit() {
|
||||
redirect_stdout.reset();
|
||||
redirect_stderr.reset();
|
||||
}
|
||||
};
|
||||
|
||||
NAMESPACE_END(detail)
|
||||
|
||||
/** \rst
|
||||
This is a helper function to add a C++ redirect context manager to Python
|
||||
instead of using a C++ guard. To use it, add the following to your binding code:
|
||||
|
||||
.. code-block:: cpp
|
||||
|
||||
#include <pybind11/iostream.h>
|
||||
|
||||
...
|
||||
|
||||
py::add_ostream_redirect(m, "ostream_redirect");
|
||||
|
||||
You now have a Python context manager that redirects your output:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
with m.ostream_redirect():
|
||||
m.print_to_cout_function()
|
||||
|
||||
This manager can optionally be told which streams to operate on:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
with m.ostream_redirect(stdout=True, stderr=True):
|
||||
m.noisy_function_with_error_printing()
|
||||
|
||||
\endrst */
|
||||
inline class_<detail::OstreamRedirect> add_ostream_redirect(module m, std::string name = "ostream_redirect") {
|
||||
return class_<detail::OstreamRedirect>(m, name.c_str(), module_local())
|
||||
.def(init<bool,bool>(), arg("stdout")=true, arg("stderr")=true)
|
||||
.def("__enter__", &detail::OstreamRedirect::enter)
|
||||
.def("__exit__", [](detail::OstreamRedirect &self_, args) { self_.exit(); });
|
||||
}
|
||||
|
||||
NAMESPACE_END(PYBIND11_NAMESPACE)
|
File diff suppressed because it is too large
|
@ -1,168 +0,0 @@
|
|||
/*
|
||||
pybind11/operator.h: Metatemplates for operator overloading
|
||||
|
||||
Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>
|
||||
|
||||
All rights reserved. Use of this source code is governed by a
|
||||
BSD-style license that can be found in the LICENSE file.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "pybind11.h"
|
||||
|
||||
#if defined(__clang__) && !defined(__INTEL_COMPILER)
|
||||
# pragma clang diagnostic ignored "-Wunsequenced" // multiple unsequenced modifications to 'self' (when using def(py::self OP Type()))
|
||||
#elif defined(_MSC_VER)
|
||||
# pragma warning(push)
|
||||
# pragma warning(disable: 4127) // warning C4127: Conditional expression is constant
|
||||
#endif
|
||||
|
||||
NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
|
||||
NAMESPACE_BEGIN(detail)
|
||||
|
||||
/// Enumeration with all supported operator types
|
||||
enum op_id : int {
|
||||
op_add, op_sub, op_mul, op_div, op_mod, op_divmod, op_pow, op_lshift,
|
||||
op_rshift, op_and, op_xor, op_or, op_neg, op_pos, op_abs, op_invert,
|
||||
op_int, op_long, op_float, op_str, op_cmp, op_gt, op_ge, op_lt, op_le,
|
||||
op_eq, op_ne, op_iadd, op_isub, op_imul, op_idiv, op_imod, op_ilshift,
|
||||
op_irshift, op_iand, op_ixor, op_ior, op_complex, op_bool, op_nonzero,
|
||||
op_repr, op_truediv, op_itruediv, op_hash
|
||||
};
|
||||
|
||||
enum op_type : int {
|
||||
op_l, /* base type on left */
|
||||
op_r, /* base type on right */
|
||||
op_u /* unary operator */
|
||||
};
|
||||
|
||||
struct self_t { };
|
||||
static const self_t self = self_t();
|
||||
|
||||
/// Type for an unused type slot
|
||||
struct undefined_t { };
|
||||
|
||||
/// Don't warn about an unused variable
|
||||
inline self_t __self() { return self; }
|
||||
|
||||
/// base template of operator implementations
|
||||
template <op_id, op_type, typename B, typename L, typename R> struct op_impl { };
|
||||
|
||||
/// Operator implementation generator
|
||||
template <op_id id, op_type ot, typename L, typename R> struct op_ {
|
||||
template <typename Class, typename... Extra> void execute(Class &cl, const Extra&... extra) const {
|
||||
using Base = typename Class::type;
|
||||
using L_type = conditional_t<std::is_same<L, self_t>::value, Base, L>;
|
||||
using R_type = conditional_t<std::is_same<R, self_t>::value, Base, R>;
|
||||
using op = op_impl<id, ot, Base, L_type, R_type>;
|
||||
cl.def(op::name(), &op::execute, is_operator(), extra...);
|
||||
#if PY_MAJOR_VERSION < 3
|
||||
if (id == op_truediv || id == op_itruediv)
|
||||
cl.def(id == op_itruediv ? "__idiv__" : ot == op_l ? "__div__" : "__rdiv__",
|
||||
&op::execute, is_operator(), extra...);
|
||||
#endif
|
||||
}
|
||||
template <typename Class, typename... Extra> void execute_cast(Class &cl, const Extra&... extra) const {
|
||||
using Base = typename Class::type;
|
||||
using L_type = conditional_t<std::is_same<L, self_t>::value, Base, L>;
|
||||
using R_type = conditional_t<std::is_same<R, self_t>::value, Base, R>;
|
||||
using op = op_impl<id, ot, Base, L_type, R_type>;
|
||||
cl.def(op::name(), &op::execute_cast, is_operator(), extra...);
|
||||
#if PY_MAJOR_VERSION < 3
|
||||
if (id == op_truediv || id == op_itruediv)
|
||||
cl.def(id == op_itruediv ? "__idiv__" : ot == op_l ? "__div__" : "__rdiv__",
|
||||
&op::execute, is_operator(), extra...);
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
#define PYBIND11_BINARY_OPERATOR(id, rid, op, expr) \
|
||||
template <typename B, typename L, typename R> struct op_impl<op_##id, op_l, B, L, R> { \
|
||||
static char const* name() { return "__" #id "__"; } \
|
||||
static auto execute(const L &l, const R &r) -> decltype(expr) { return (expr); } \
|
||||
static B execute_cast(const L &l, const R &r) { return B(expr); } \
|
||||
}; \
|
||||
template <typename B, typename L, typename R> struct op_impl<op_##id, op_r, B, L, R> { \
|
||||
static char const* name() { return "__" #rid "__"; } \
|
||||
static auto execute(const R &r, const L &l) -> decltype(expr) { return (expr); } \
|
||||
static B execute_cast(const R &r, const L &l) { return B(expr); } \
|
||||
}; \
|
||||
inline op_<op_##id, op_l, self_t, self_t> op(const self_t &, const self_t &) { \
|
||||
return op_<op_##id, op_l, self_t, self_t>(); \
|
||||
} \
|
||||
template <typename T> op_<op_##id, op_l, self_t, T> op(const self_t &, const T &) { \
|
||||
return op_<op_##id, op_l, self_t, T>(); \
|
||||
} \
|
||||
template <typename T> op_<op_##id, op_r, T, self_t> op(const T &, const self_t &) { \
|
||||
return op_<op_##id, op_r, T, self_t>(); \
|
||||
}
|
||||
|
||||
#define PYBIND11_INPLACE_OPERATOR(id, op, expr) \
|
||||
template <typename B, typename L, typename R> struct op_impl<op_##id, op_l, B, L, R> { \
|
||||
static char const* name() { return "__" #id "__"; } \
|
||||
static auto execute(L &l, const R &r) -> decltype(expr) { return expr; } \
|
||||
static B execute_cast(L &l, const R &r) { return B(expr); } \
|
||||
}; \
|
||||
template <typename T> op_<op_##id, op_l, self_t, T> op(const self_t &, const T &) { \
|
||||
return op_<op_##id, op_l, self_t, T>(); \
|
||||
}
|
||||
|
||||
#define PYBIND11_UNARY_OPERATOR(id, op, expr) \
|
||||
template <typename B, typename L> struct op_impl<op_##id, op_u, B, L, undefined_t> { \
|
||||
static char const* name() { return "__" #id "__"; } \
|
||||
static auto execute(const L &l) -> decltype(expr) { return expr; } \
|
||||
static B execute_cast(const L &l) { return B(expr); } \
|
||||
}; \
|
||||
inline op_<op_##id, op_u, self_t, undefined_t> op(const self_t &) { \
|
||||
return op_<op_##id, op_u, self_t, undefined_t>(); \
|
||||
}
|
||||
|
||||
PYBIND11_BINARY_OPERATOR(sub, rsub, operator-, l - r)
|
||||
PYBIND11_BINARY_OPERATOR(add, radd, operator+, l + r)
|
||||
PYBIND11_BINARY_OPERATOR(mul, rmul, operator*, l * r)
|
||||
PYBIND11_BINARY_OPERATOR(truediv, rtruediv, operator/, l / r)
|
||||
PYBIND11_BINARY_OPERATOR(mod, rmod, operator%, l % r)
|
||||
PYBIND11_BINARY_OPERATOR(lshift, rlshift, operator<<, l << r)
|
||||
PYBIND11_BINARY_OPERATOR(rshift, rrshift, operator>>, l >> r)
|
||||
PYBIND11_BINARY_OPERATOR(and, rand, operator&, l & r)
|
||||
PYBIND11_BINARY_OPERATOR(xor, rxor, operator^, l ^ r)
|
||||
PYBIND11_BINARY_OPERATOR(eq, eq, operator==, l == r)
|
||||
PYBIND11_BINARY_OPERATOR(ne, ne, operator!=, l != r)
|
||||
PYBIND11_BINARY_OPERATOR(or, ror, operator|, l | r)
|
||||
PYBIND11_BINARY_OPERATOR(gt, lt, operator>, l > r)
|
||||
PYBIND11_BINARY_OPERATOR(ge, le, operator>=, l >= r)
|
||||
PYBIND11_BINARY_OPERATOR(lt, gt, operator<, l < r)
|
||||
PYBIND11_BINARY_OPERATOR(le, ge, operator<=, l <= r)
|
||||
//PYBIND11_BINARY_OPERATOR(pow, rpow, pow, std::pow(l, r))
|
||||
PYBIND11_INPLACE_OPERATOR(iadd, operator+=, l += r)
|
||||
PYBIND11_INPLACE_OPERATOR(isub, operator-=, l -= r)
|
||||
PYBIND11_INPLACE_OPERATOR(imul, operator*=, l *= r)
|
||||
PYBIND11_INPLACE_OPERATOR(itruediv, operator/=, l /= r)
|
||||
PYBIND11_INPLACE_OPERATOR(imod, operator%=, l %= r)
|
||||
PYBIND11_INPLACE_OPERATOR(ilshift, operator<<=, l <<= r)
|
||||
PYBIND11_INPLACE_OPERATOR(irshift, operator>>=, l >>= r)
|
||||
PYBIND11_INPLACE_OPERATOR(iand, operator&=, l &= r)
|
||||
PYBIND11_INPLACE_OPERATOR(ixor, operator^=, l ^= r)
|
||||
PYBIND11_INPLACE_OPERATOR(ior, operator|=, l |= r)
|
||||
PYBIND11_UNARY_OPERATOR(neg, operator-, -l)
|
||||
PYBIND11_UNARY_OPERATOR(pos, operator+, +l)
|
||||
PYBIND11_UNARY_OPERATOR(abs, abs, std::abs(l))
|
||||
PYBIND11_UNARY_OPERATOR(hash, hash, std::hash<L>()(l))
|
||||
PYBIND11_UNARY_OPERATOR(invert, operator~, (~l))
|
||||
PYBIND11_UNARY_OPERATOR(bool, operator!, !!l)
|
||||
PYBIND11_UNARY_OPERATOR(int, int_, (int) l)
|
||||
PYBIND11_UNARY_OPERATOR(float, float_, (double) l)
|
||||
|
||||
#undef PYBIND11_BINARY_OPERATOR
|
||||
#undef PYBIND11_INPLACE_OPERATOR
|
||||
#undef PYBIND11_UNARY_OPERATOR
|
||||
NAMESPACE_END(detail)
|
||||
|
||||
using detail::self;
|
||||
|
||||
NAMESPACE_END(PYBIND11_NAMESPACE)
|
||||
|
||||
#if defined(_MSC_VER)
|
||||
# pragma warning(pop)
|
||||
#endif
|
|
@ -1,65 +0,0 @@
|
|||
/*
|
||||
pybind11/options.h: global settings that are configurable at runtime.
|
||||
|
||||
Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>
|
||||
|
||||
All rights reserved. Use of this source code is governed by a
|
||||
BSD-style license that can be found in the LICENSE file.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "detail/common.h"
|
||||
|
||||
NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
|
||||
|
||||
class options {
|
||||
public:
|
||||
|
||||
// Default RAII constructor, which leaves settings as they currently are.
|
||||
options() : previous_state(global_state()) {}
|
||||
|
||||
// Class is non-copyable.
|
||||
options(const options&) = delete;
|
||||
options& operator=(const options&) = delete;
|
||||
|
||||
// Destructor, which restores settings that were in effect before.
|
||||
~options() {
|
||||
global_state() = previous_state;
|
||||
}
|
||||
|
||||
// Setter methods (affect the global state):
|
||||
|
||||
options& disable_user_defined_docstrings() & { global_state().show_user_defined_docstrings = false; return *this; }
|
||||
|
||||
options& enable_user_defined_docstrings() & { global_state().show_user_defined_docstrings = true; return *this; }
|
||||
|
||||
options& disable_function_signatures() & { global_state().show_function_signatures = false; return *this; }
|
||||
|
||||
options& enable_function_signatures() & { global_state().show_function_signatures = true; return *this; }
|
||||
|
||||
// Getter methods (return the global state):
|
||||
|
||||
static bool show_user_defined_docstrings() { return global_state().show_user_defined_docstrings; }
|
||||
|
||||
static bool show_function_signatures() { return global_state().show_function_signatures; }
|
||||
|
||||
// This type is not meant to be allocated on the heap.
|
||||
void* operator new(size_t) = delete;
|
||||
|
||||
private:
|
||||
|
||||
struct state {
|
||||
bool show_user_defined_docstrings = true; ///< Include user-supplied text in docstrings.
|
||||
bool show_function_signatures = true; ///< Include auto-generated function signatures in docstrings.
|
||||
};
|
||||
|
||||
static state &global_state() {
|
||||
static state instance;
|
||||
return instance;
|
||||
}
|
||||
|
||||
state previous_state;
|
||||
};
|
||||
|
||||
NAMESPACE_END(PYBIND11_NAMESPACE)
|
File diff suppressed because it is too large
File diff suppressed because it is too large
|
@ -1,386 +0,0 @@
|
|||
/*
|
||||
pybind11/stl.h: Transparent conversion for STL data types
|
||||
|
||||
Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>
|
||||
|
||||
All rights reserved. Use of this source code is governed by a
|
||||
BSD-style license that can be found in the LICENSE file.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "pybind11.h"
|
||||
#include <set>
|
||||
#include <unordered_set>
|
||||
#include <map>
|
||||
#include <unordered_map>
|
||||
#include <iostream>
|
||||
#include <list>
|
||||
#include <deque>
|
||||
#include <valarray>
|
||||
|
||||
#if defined(_MSC_VER)
|
||||
#pragma warning(push)
|
||||
#pragma warning(disable: 4127) // warning C4127: Conditional expression is constant
|
||||
#endif
|
||||
|
||||
#ifdef __has_include
|
||||
// std::optional (but including it in c++14 mode isn't allowed)
|
||||
# if defined(PYBIND11_CPP17) && __has_include(<optional>)
|
||||
# include <optional>
|
||||
# define PYBIND11_HAS_OPTIONAL 1
|
||||
# endif
|
||||
// std::experimental::optional (but not allowed in c++11 mode)
|
||||
# if defined(PYBIND11_CPP14) && (__has_include(<experimental/optional>) && \
|
||||
!__has_include(<optional>))
|
||||
# include <experimental/optional>
|
||||
# define PYBIND11_HAS_EXP_OPTIONAL 1
|
||||
# endif
|
||||
// std::variant
|
||||
# if defined(PYBIND11_CPP17) && __has_include(<variant>)
|
||||
# include <variant>
|
||||
# define PYBIND11_HAS_VARIANT 1
|
||||
# endif
|
||||
#elif defined(_MSC_VER) && defined(PYBIND11_CPP17)
|
||||
# include <optional>
|
||||
# include <variant>
|
||||
# define PYBIND11_HAS_OPTIONAL 1
|
||||
# define PYBIND11_HAS_VARIANT 1
|
||||
#endif
|
||||
|
||||
NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
|
||||
NAMESPACE_BEGIN(detail)
|
||||
|
||||
/// Extracts a const lvalue reference or rvalue reference for U based on the type of T (e.g. for
|
||||
/// forwarding a container element). Typically used indirectly via forwarded_type(), below.
|
||||
template <typename T, typename U>
|
||||
using forwarded_type = conditional_t<
|
||||
std::is_lvalue_reference<T>::value, remove_reference_t<U> &, remove_reference_t<U> &&>;
|
||||
|
||||
/// Forwards a value U as rvalue or lvalue according to whether T is rvalue or lvalue; typically
|
||||
/// used for forwarding a container's elements.
|
||||
template <typename T, typename U>
|
||||
forwarded_type<T, U> forward_like(U &&u) {
|
||||
return std::forward<detail::forwarded_type<T, U>>(std::forward<U>(u));
|
||||
}
|
||||
|
||||
template <typename Type, typename Key> struct set_caster {
|
||||
using type = Type;
|
||||
using key_conv = make_caster<Key>;
|
||||
|
||||
bool load(handle src, bool convert) {
|
||||
if (!isinstance<pybind11::set>(src))
|
||||
return false;
|
||||
auto s = reinterpret_borrow<pybind11::set>(src);
|
||||
value.clear();
|
||||
for (auto entry : s) {
|
||||
key_conv conv;
|
||||
if (!conv.load(entry, convert))
|
||||
return false;
|
||||
value.insert(cast_op<Key &&>(std::move(conv)));
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static handle cast(T &&src, return_value_policy policy, handle parent) {
|
||||
if (!std::is_lvalue_reference<T>::value)
|
||||
policy = return_value_policy_override<Key>::policy(policy);
|
||||
pybind11::set s;
|
||||
for (auto &&value : src) {
|
||||
auto value_ = reinterpret_steal<object>(key_conv::cast(forward_like<T>(value), policy, parent));
|
||||
if (!value_ || !s.add(value_))
|
||||
return handle();
|
||||
}
|
||||
return s.release();
|
||||
}
|
||||
|
||||
PYBIND11_TYPE_CASTER(type, _("Set[") + key_conv::name + _("]"));
|
||||
};
|
||||
|
||||
template <typename Type, typename Key, typename Value> struct map_caster {
|
||||
using key_conv = make_caster<Key>;
|
||||
using value_conv = make_caster<Value>;
|
||||
|
||||
bool load(handle src, bool convert) {
|
||||
if (!isinstance<dict>(src))
|
||||
return false;
|
||||
auto d = reinterpret_borrow<dict>(src);
|
||||
value.clear();
|
||||
for (auto it : d) {
|
||||
key_conv kconv;
|
||||
value_conv vconv;
|
||||
if (!kconv.load(it.first.ptr(), convert) ||
|
||||
!vconv.load(it.second.ptr(), convert))
|
||||
return false;
|
||||
value.emplace(cast_op<Key &&>(std::move(kconv)), cast_op<Value &&>(std::move(vconv)));
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static handle cast(T &&src, return_value_policy policy, handle parent) {
|
||||
dict d;
|
||||
return_value_policy policy_key = policy;
|
||||
return_value_policy policy_value = policy;
|
||||
if (!std::is_lvalue_reference<T>::value) {
|
||||
policy_key = return_value_policy_override<Key>::policy(policy_key);
|
||||
policy_value = return_value_policy_override<Value>::policy(policy_value);
|
||||
}
|
||||
for (auto &&kv : src) {
|
||||
auto key = reinterpret_steal<object>(key_conv::cast(forward_like<T>(kv.first), policy_key, parent));
|
||||
auto value = reinterpret_steal<object>(value_conv::cast(forward_like<T>(kv.second), policy_value, parent));
|
||||
if (!key || !value)
|
||||
return handle();
|
||||
d[key] = value;
|
||||
}
|
||||
return d.release();
|
||||
}
|
||||
|
||||
PYBIND11_TYPE_CASTER(Type, _("Dict[") + key_conv::name + _(", ") + value_conv::name + _("]"));
|
||||
};
|
||||
|
||||
template <typename Type, typename Value> struct list_caster {
|
||||
using value_conv = make_caster<Value>;
|
||||
|
||||
bool load(handle src, bool convert) {
|
||||
if (!isinstance<sequence>(src) || isinstance<str>(src))
|
||||
return false;
|
||||
auto s = reinterpret_borrow<sequence>(src);
|
||||
value.clear();
|
||||
reserve_maybe(s, &value);
|
||||
for (auto it : s) {
|
||||
value_conv conv;
|
||||
if (!conv.load(it, convert))
|
||||
return false;
|
||||
value.push_back(cast_op<Value &&>(std::move(conv)));
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
private:
|
||||
template <typename T = Type,
|
||||
enable_if_t<std::is_same<decltype(std::declval<T>().reserve(0)), void>::value, int> = 0>
|
||||
void reserve_maybe(sequence s, Type *) { value.reserve(s.size()); }
|
||||
void reserve_maybe(sequence, void *) { }
|
||||
|
||||
public:
|
||||
template <typename T>
|
||||
static handle cast(T &&src, return_value_policy policy, handle parent) {
|
||||
if (!std::is_lvalue_reference<T>::value)
|
||||
policy = return_value_policy_override<Value>::policy(policy);
|
||||
list l(src.size());
|
||||
size_t index = 0;
|
||||
for (auto &&value : src) {
|
||||
auto value_ = reinterpret_steal<object>(value_conv::cast(forward_like<T>(value), policy, parent));
|
||||
if (!value_)
|
||||
return handle();
|
||||
PyList_SET_ITEM(l.ptr(), (ssize_t) index++, value_.release().ptr()); // steals a reference
|
||||
}
|
||||
return l.release();
|
||||
}
|
||||
|
||||
PYBIND11_TYPE_CASTER(Type, _("List[") + value_conv::name + _("]"));
|
||||
};
|
||||
|
||||
template <typename Type, typename Alloc> struct type_caster<std::vector<Type, Alloc>>
|
||||
: list_caster<std::vector<Type, Alloc>, Type> { };
|
||||
|
||||
template <typename Type, typename Alloc> struct type_caster<std::deque<Type, Alloc>>
|
||||
: list_caster<std::deque<Type, Alloc>, Type> { };
|
||||
|
||||
template <typename Type, typename Alloc> struct type_caster<std::list<Type, Alloc>>
|
||||
: list_caster<std::list<Type, Alloc>, Type> { };
|
||||
|
||||
template <typename ArrayType, typename Value, bool Resizable, size_t Size = 0> struct array_caster {
|
||||
using value_conv = make_caster<Value>;
|
||||
|
||||
private:
|
||||
template <bool R = Resizable>
|
||||
bool require_size(enable_if_t<R, size_t> size) {
|
||||
if (value.size() != size)
|
||||
value.resize(size);
|
||||
return true;
|
||||
}
|
||||
template <bool R = Resizable>
|
||||
bool require_size(enable_if_t<!R, size_t> size) {
|
||||
return size == Size;
|
||||
}
|
||||
|
||||
public:
|
||||
bool load(handle src, bool convert) {
|
||||
if (!isinstance<sequence>(src))
|
||||
return false;
|
||||
auto l = reinterpret_borrow<sequence>(src);
|
||||
if (!require_size(l.size()))
|
||||
return false;
|
||||
size_t ctr = 0;
|
||||
for (auto it : l) {
|
||||
value_conv conv;
|
||||
if (!conv.load(it, convert))
|
||||
return false;
|
||||
value[ctr++] = cast_op<Value &&>(std::move(conv));
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static handle cast(T &&src, return_value_policy policy, handle parent) {
|
||||
list l(src.size());
|
||||
size_t index = 0;
|
||||
for (auto &&value : src) {
|
||||
auto value_ = reinterpret_steal<object>(value_conv::cast(forward_like<T>(value), policy, parent));
|
||||
if (!value_)
|
||||
return handle();
|
||||
PyList_SET_ITEM(l.ptr(), (ssize_t) index++, value_.release().ptr()); // steals a reference
|
||||
}
|
||||
return l.release();
|
||||
}
|
||||
|
||||
PYBIND11_TYPE_CASTER(ArrayType, _("List[") + value_conv::name + _<Resizable>(_(""), _("[") + _<Size>() + _("]")) + _("]"));
|
||||
};
|
||||
|
||||
template <typename Type, size_t Size> struct type_caster<std::array<Type, Size>>
|
||||
: array_caster<std::array<Type, Size>, Type, false, Size> { };
|
||||
|
||||
template <typename Type> struct type_caster<std::valarray<Type>>
|
||||
: array_caster<std::valarray<Type>, Type, true> { };
|
||||
|
||||
template <typename Key, typename Compare, typename Alloc> struct type_caster<std::set<Key, Compare, Alloc>>
|
||||
: set_caster<std::set<Key, Compare, Alloc>, Key> { };
|
||||
|
||||
template <typename Key, typename Hash, typename Equal, typename Alloc> struct type_caster<std::unordered_set<Key, Hash, Equal, Alloc>>
|
||||
: set_caster<std::unordered_set<Key, Hash, Equal, Alloc>, Key> { };
|
||||
|
||||
template <typename Key, typename Value, typename Compare, typename Alloc> struct type_caster<std::map<Key, Value, Compare, Alloc>>
|
||||
: map_caster<std::map<Key, Value, Compare, Alloc>, Key, Value> { };
|
||||
|
||||
template <typename Key, typename Value, typename Hash, typename Equal, typename Alloc> struct type_caster<std::unordered_map<Key, Value, Hash, Equal, Alloc>>
|
||||
: map_caster<std::unordered_map<Key, Value, Hash, Equal, Alloc>, Key, Value> { };
|
||||
|
||||
// This type caster is intended to be used for std::optional and std::experimental::optional
|
||||
template<typename T> struct optional_caster {
|
||||
using value_conv = make_caster<typename T::value_type>;
|
||||
|
||||
template <typename T_>
|
||||
static handle cast(T_ &&src, return_value_policy policy, handle parent) {
|
||||
if (!src)
|
||||
return none().inc_ref();
|
||||
policy = return_value_policy_override<typename T::value_type>::policy(policy);
|
||||
return value_conv::cast(*std::forward<T_>(src), policy, parent);
|
||||
}
|
||||
|
||||
bool load(handle src, bool convert) {
|
||||
if (!src) {
|
||||
return false;
|
||||
} else if (src.is_none()) {
|
||||
return true; // default-constructed value is already empty
|
||||
}
|
||||
value_conv inner_caster;
|
||||
if (!inner_caster.load(src, convert))
|
||||
return false;
|
||||
|
||||
value.emplace(cast_op<typename T::value_type &&>(std::move(inner_caster)));
|
||||
return true;
|
||||
}
|
||||
|
||||
PYBIND11_TYPE_CASTER(T, _("Optional[") + value_conv::name + _("]"));
|
||||
};
|
||||
|
||||
#if PYBIND11_HAS_OPTIONAL
|
||||
template<typename T> struct type_caster<std::optional<T>>
|
||||
: public optional_caster<std::optional<T>> {};
|
||||
|
||||
template<> struct type_caster<std::nullopt_t>
|
||||
: public void_caster<std::nullopt_t> {};
|
||||
#endif
|
||||
|
||||
#if PYBIND11_HAS_EXP_OPTIONAL
|
||||
template<typename T> struct type_caster<std::experimental::optional<T>>
|
||||
: public optional_caster<std::experimental::optional<T>> {};
|
||||
|
||||
template<> struct type_caster<std::experimental::nullopt_t>
|
||||
: public void_caster<std::experimental::nullopt_t> {};
|
||||
#endif
|
||||
|
||||
/// Visit a variant and cast any found type to Python
|
||||
struct variant_caster_visitor {
|
||||
return_value_policy policy;
|
||||
handle parent;
|
||||
|
||||
using result_type = handle; // required by boost::variant in C++11
|
||||
|
||||
template <typename T>
|
||||
result_type operator()(T &&src) const {
|
||||
return make_caster<T>::cast(std::forward<T>(src), policy, parent);
|
||||
}
|
||||
};
|
||||
|
||||
/// Helper class which abstracts away variant's `visit` function. `std::variant` and similar
|
||||
/// `namespace::variant` types which provide a `namespace::visit()` function are handled here
|
||||
/// automatically using argument-dependent lookup. Users can provide specializations for other
|
||||
/// variant-like classes, e.g. `boost::variant` and `boost::apply_visitor`.
|
||||
template <template<typename...> class Variant>
|
||||
struct visit_helper {
|
||||
template <typename... Args>
|
||||
static auto call(Args &&...args) -> decltype(visit(std::forward<Args>(args)...)) {
|
||||
return visit(std::forward<Args>(args)...);
|
||||
}
|
||||
};
|
||||
|
||||
/// Generic variant caster
|
||||
template <typename Variant> struct variant_caster;
|
||||
|
||||
template <template<typename...> class V, typename... Ts>
|
||||
struct variant_caster<V<Ts...>> {
|
||||
static_assert(sizeof...(Ts) > 0, "Variant must consist of at least one alternative.");
|
||||
|
||||
template <typename U, typename... Us>
|
||||
bool load_alternative(handle src, bool convert, type_list<U, Us...>) {
|
||||
auto caster = make_caster<U>();
|
||||
if (caster.load(src, convert)) {
|
||||
value = cast_op<U>(caster);
|
||||
return true;
|
||||
}
|
||||
return load_alternative(src, convert, type_list<Us...>{});
|
||||
}
|
||||
|
||||
bool load_alternative(handle, bool, type_list<>) { return false; }
|
||||
|
||||
bool load(handle src, bool convert) {
|
||||
// Do a first pass without conversions to improve constructor resolution.
|
||||
// E.g. `py::int_(1).cast<variant<double, int>>()` needs to fill the `int`
|
||||
// slot of the variant. Without two-pass loading `double` would be filled
|
||||
// because it appears first and a conversion is possible.
|
||||
if (convert && load_alternative(src, false, type_list<Ts...>{}))
|
||||
return true;
|
||||
return load_alternative(src, convert, type_list<Ts...>{});
|
||||
}
|
||||
|
||||
template <typename Variant>
|
||||
static handle cast(Variant &&src, return_value_policy policy, handle parent) {
|
||||
return visit_helper<V>::call(variant_caster_visitor{policy, parent},
|
||||
std::forward<Variant>(src));
|
||||
}
|
||||
|
||||
using Type = V<Ts...>;
|
||||
PYBIND11_TYPE_CASTER(Type, _("Union[") + detail::concat(make_caster<Ts>::name...) + _("]"));
|
||||
};
|
||||
|
||||
#if PYBIND11_HAS_VARIANT
|
||||
template <typename... Ts>
|
||||
struct type_caster<std::variant<Ts...>> : variant_caster<std::variant<Ts...>> { };
|
||||
#endif
|
||||
|
||||
NAMESPACE_END(detail)
|
||||
|
||||
inline std::ostream &operator<<(std::ostream &os, const handle &obj) {
|
||||
os << (std::string) str(obj);
|
||||
return os;
|
||||
}
|
||||
|
||||
NAMESPACE_END(PYBIND11_NAMESPACE)
|
||||
|
||||
#if defined(_MSC_VER)
|
||||
#pragma warning(pop)
|
||||
#endif
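// ---------------------------------------------------------------------------
// Illustrative usage sketch (editor's addition, not part of the original
// header): with this header included, the casters above convert STL
// containers (and, under C++17, std::optional) by copy at the call boundary.
// The module and function names below are hypothetical.
//
//     #include <pybind11/pybind11.h>
//     #include <pybind11/stl.h>
//     #include <optional>
//     #include <vector>
//
//     std::optional<int> head_or_none(const std::vector<int> &v) {
//         if (v.empty()) return std::nullopt;
//         return v.front();
//     }
//
//     PYBIND11_MODULE(stl_cast_demo, m) {
//         m.def("head_or_none", &head_or_none);   // accepts any Python sequence of ints
//     }
//
// From Python: stl_cast_demo.head_or_none([7, 8]) == 7 and
//              stl_cast_demo.head_or_none([]) is None.
// ---------------------------------------------------------------------------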
|
|
@ -1,599 +0,0 @@
|
|||
/*
|
||||
pybind11/std_bind.h: Binding generators for STL data types
|
||||
|
||||
Copyright (c) 2016 Sergey Lyskov and Wenzel Jakob
|
||||
|
||||
All rights reserved. Use of this source code is governed by a
|
||||
BSD-style license that can be found in the LICENSE file.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "detail/common.h"
|
||||
#include "operators.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <sstream>
|
||||
|
||||
NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
|
||||
NAMESPACE_BEGIN(detail)
|
||||
|
||||
/* SFINAE helper class used by 'is_comparable' */
|
||||
template <typename T> struct container_traits {
|
||||
template <typename T2> static std::true_type test_comparable(decltype(std::declval<const T2 &>() == std::declval<const T2 &>())*);
|
||||
template <typename T2> static std::false_type test_comparable(...);
|
||||
template <typename T2> static std::true_type test_value(typename T2::value_type *);
|
||||
template <typename T2> static std::false_type test_value(...);
|
||||
template <typename T2> static std::true_type test_pair(typename T2::first_type *, typename T2::second_type *);
|
||||
template <typename T2> static std::false_type test_pair(...);
|
||||
|
||||
static constexpr const bool is_comparable = std::is_same<std::true_type, decltype(test_comparable<T>(nullptr))>::value;
|
||||
static constexpr const bool is_pair = std::is_same<std::true_type, decltype(test_pair<T>(nullptr, nullptr))>::value;
|
||||
static constexpr const bool is_vector = std::is_same<std::true_type, decltype(test_value<T>(nullptr))>::value;
|
||||
static constexpr const bool is_element = !is_pair && !is_vector;
|
||||
};
|
||||
|
||||
/* Default: is_comparable -> std::false_type */
|
||||
template <typename T, typename SFINAE = void>
|
||||
struct is_comparable : std::false_type { };
|
||||
|
||||
/* For non-map data structures, check whether operator== can be instantiated */
|
||||
template <typename T>
|
||||
struct is_comparable<
|
||||
T, enable_if_t<container_traits<T>::is_element &&
|
||||
container_traits<T>::is_comparable>>
|
||||
: std::true_type { };
|
||||
|
||||
/* For a vector/map data structure, recursively check the value type (which is std::pair for maps) */
|
||||
template <typename T>
|
||||
struct is_comparable<T, enable_if_t<container_traits<T>::is_vector>> {
|
||||
static constexpr const bool value =
|
||||
is_comparable<typename T::value_type>::value;
|
||||
};
|
||||
|
||||
/* For pairs, recursively check the two data types */
|
||||
template <typename T>
|
||||
struct is_comparable<T, enable_if_t<container_traits<T>::is_pair>> {
|
||||
static constexpr const bool value =
|
||||
is_comparable<typename T::first_type>::value &&
|
||||
is_comparable<typename T::second_type>::value;
|
||||
};
|
||||
|
||||
/* Fallback functions */
|
||||
template <typename, typename, typename... Args> void vector_if_copy_constructible(const Args &...) { }
|
||||
template <typename, typename, typename... Args> void vector_if_equal_operator(const Args &...) { }
|
||||
template <typename, typename, typename... Args> void vector_if_insertion_operator(const Args &...) { }
|
||||
template <typename, typename, typename... Args> void vector_modifiers(const Args &...) { }
|
||||
|
||||
template<typename Vector, typename Class_>
|
||||
void vector_if_copy_constructible(enable_if_t<is_copy_constructible<Vector>::value, Class_> &cl) {
|
||||
cl.def(init<const Vector &>(), "Copy constructor");
|
||||
}
|
||||
|
||||
template<typename Vector, typename Class_>
|
||||
void vector_if_equal_operator(enable_if_t<is_comparable<Vector>::value, Class_> &cl) {
|
||||
using T = typename Vector::value_type;
|
||||
|
||||
cl.def(self == self);
|
||||
cl.def(self != self);
|
||||
|
||||
cl.def("count",
|
||||
[](const Vector &v, const T &x) {
|
||||
return std::count(v.begin(), v.end(), x);
|
||||
},
|
||||
arg("x"),
|
||||
"Return the number of times ``x`` appears in the list"
|
||||
);
|
||||
|
||||
cl.def("remove", [](Vector &v, const T &x) {
|
||||
auto p = std::find(v.begin(), v.end(), x);
|
||||
if (p != v.end())
|
||||
v.erase(p);
|
||||
else
|
||||
throw value_error();
|
||||
},
|
||||
arg("x"),
|
||||
"Remove the first item from the list whose value is x. "
|
||||
"It is an error if there is no such item."
|
||||
);
|
||||
|
||||
cl.def("__contains__",
|
||||
[](const Vector &v, const T &x) {
|
||||
return std::find(v.begin(), v.end(), x) != v.end();
|
||||
},
|
||||
arg("x"),
|
||||
"Return true the container contains ``x``"
|
||||
);
|
||||
}
|
||||
|
||||
// Vector modifiers -- requires a copyable vector_type:
|
||||
// (Technically, some of these (pop and __delitem__) don't actually require copyability, but it seems
|
||||
// silly to allow deletion but not insertion, so include them here too.)
|
||||
template <typename Vector, typename Class_>
|
||||
void vector_modifiers(enable_if_t<is_copy_constructible<typename Vector::value_type>::value, Class_> &cl) {
|
||||
using T = typename Vector::value_type;
|
||||
using SizeType = typename Vector::size_type;
|
||||
using DiffType = typename Vector::difference_type;
|
||||
|
||||
cl.def("append",
|
||||
[](Vector &v, const T &value) { v.push_back(value); },
|
||||
arg("x"),
|
||||
"Add an item to the end of the list");
|
||||
|
||||
cl.def(init([](iterable it) {
|
||||
auto v = std::unique_ptr<Vector>(new Vector());
|
||||
v->reserve(len(it));
|
||||
for (handle h : it)
|
||||
v->push_back(h.cast<T>());
|
||||
return v.release();
|
||||
}));
|
||||
|
||||
cl.def("extend",
|
||||
[](Vector &v, const Vector &src) {
|
||||
v.insert(v.end(), src.begin(), src.end());
|
||||
},
|
||||
arg("L"),
|
||||
"Extend the list by appending all the items in the given list"
|
||||
);
|
||||
|
||||
cl.def("insert",
|
||||
[](Vector &v, SizeType i, const T &x) {
|
||||
if (i > v.size())
|
||||
throw index_error();
|
||||
v.insert(v.begin() + (DiffType) i, x);
|
||||
},
|
||||
arg("i") , arg("x"),
|
||||
"Insert an item at a given position."
|
||||
);
|
||||
|
||||
cl.def("pop",
|
||||
[](Vector &v) {
|
||||
if (v.empty())
|
||||
throw index_error();
|
||||
T t = v.back();
|
||||
v.pop_back();
|
||||
return t;
|
||||
},
|
||||
"Remove and return the last item"
|
||||
);
|
||||
|
||||
cl.def("pop",
|
||||
[](Vector &v, SizeType i) {
|
||||
if (i >= v.size())
|
||||
throw index_error();
|
||||
T t = v[i];
|
||||
v.erase(v.begin() + (DiffType) i);
|
||||
return t;
|
||||
},
|
||||
arg("i"),
|
||||
"Remove and return the item at index ``i``"
|
||||
);
|
||||
|
||||
cl.def("__setitem__",
|
||||
[](Vector &v, SizeType i, const T &t) {
|
||||
if (i >= v.size())
|
||||
throw index_error();
|
||||
v[i] = t;
|
||||
}
|
||||
);
|
||||
|
||||
/// Slicing protocol
|
||||
cl.def("__getitem__",
|
||||
[](const Vector &v, slice slice) -> Vector * {
|
||||
size_t start, stop, step, slicelength;
|
||||
|
||||
if (!slice.compute(v.size(), &start, &stop, &step, &slicelength))
|
||||
throw error_already_set();
|
||||
|
||||
Vector *seq = new Vector();
|
||||
seq->reserve((size_t) slicelength);
|
||||
|
||||
for (size_t i=0; i<slicelength; ++i) {
|
||||
seq->push_back(v[start]);
|
||||
start += step;
|
||||
}
|
||||
return seq;
|
||||
},
|
||||
arg("s"),
|
||||
"Retrieve list elements using a slice object"
|
||||
);
|
||||
|
||||
cl.def("__setitem__",
|
||||
[](Vector &v, slice slice, const Vector &value) {
|
||||
size_t start, stop, step, slicelength;
|
||||
if (!slice.compute(v.size(), &start, &stop, &step, &slicelength))
|
||||
throw error_already_set();
|
||||
|
||||
if (slicelength != value.size())
|
||||
throw std::runtime_error("Left and right hand size of slice assignment have different sizes!");
|
||||
|
||||
for (size_t i=0; i<slicelength; ++i) {
|
||||
v[start] = value[i];
|
||||
start += step;
|
||||
}
|
||||
},
|
||||
"Assign list elements using a slice object"
|
||||
);
|
||||
|
||||
cl.def("__delitem__",
|
||||
[](Vector &v, SizeType i) {
|
||||
if (i >= v.size())
|
||||
throw index_error();
|
||||
v.erase(v.begin() + DiffType(i));
|
||||
},
|
||||
"Delete the list elements at index ``i``"
|
||||
);
|
||||
|
||||
cl.def("__delitem__",
|
||||
[](Vector &v, slice slice) {
|
||||
size_t start, stop, step, slicelength;
|
||||
|
||||
if (!slice.compute(v.size(), &start, &stop, &step, &slicelength))
|
||||
throw error_already_set();
|
||||
|
||||
if (step == 1 && false) {
|
||||
v.erase(v.begin() + (DiffType) start, v.begin() + DiffType(start + slicelength));
|
||||
} else {
|
||||
for (size_t i = 0; i < slicelength; ++i) {
|
||||
v.erase(v.begin() + DiffType(start));
|
||||
start += step - 1;
|
||||
}
|
||||
}
|
||||
},
|
||||
"Delete list elements using a slice object"
|
||||
);
|
||||
|
||||
}
|
||||
|
||||
// If the type has an operator[] that doesn't return a reference (most notably std::vector<bool>),
|
||||
// we have to access by copying; otherwise we return by reference.
|
||||
template <typename Vector> using vector_needs_copy = negation<
|
||||
std::is_same<decltype(std::declval<Vector>()[typename Vector::size_type()]), typename Vector::value_type &>>;
|
||||
|
||||
// The usual case: access and iterate by reference
|
||||
template <typename Vector, typename Class_>
|
||||
void vector_accessor(enable_if_t<!vector_needs_copy<Vector>::value, Class_> &cl) {
|
||||
using T = typename Vector::value_type;
|
||||
using SizeType = typename Vector::size_type;
|
||||
using ItType = typename Vector::iterator;
|
||||
|
||||
cl.def("__getitem__",
|
||||
[](Vector &v, SizeType i) -> T & {
|
||||
if (i >= v.size())
|
||||
throw index_error();
|
||||
return v[i];
|
||||
},
|
||||
return_value_policy::reference_internal // ref + keepalive
|
||||
);
|
||||
|
||||
cl.def("__iter__",
|
||||
[](Vector &v) {
|
||||
return make_iterator<
|
||||
return_value_policy::reference_internal, ItType, ItType, T&>(
|
||||
v.begin(), v.end());
|
||||
},
|
||||
keep_alive<0, 1>() /* Essential: keep list alive while iterator exists */
|
||||
);
|
||||
}
|
||||
|
||||
// The case for special objects, like std::vector<bool>, that have to be returned-by-copy:
|
||||
template <typename Vector, typename Class_>
|
||||
void vector_accessor(enable_if_t<vector_needs_copy<Vector>::value, Class_> &cl) {
|
||||
using T = typename Vector::value_type;
|
||||
using SizeType = typename Vector::size_type;
|
||||
using ItType = typename Vector::iterator;
|
||||
cl.def("__getitem__",
|
||||
[](const Vector &v, SizeType i) -> T {
|
||||
if (i >= v.size())
|
||||
throw index_error();
|
||||
return v[i];
|
||||
}
|
||||
);
|
||||
|
||||
cl.def("__iter__",
|
||||
[](Vector &v) {
|
||||
return make_iterator<
|
||||
return_value_policy::copy, ItType, ItType, T>(
|
||||
v.begin(), v.end());
|
||||
},
|
||||
keep_alive<0, 1>() /* Essential: keep list alive while iterator exists */
|
||||
);
|
||||
}
|
||||
|
||||
template <typename Vector, typename Class_> auto vector_if_insertion_operator(Class_ &cl, std::string const &name)
|
||||
-> decltype(std::declval<std::ostream&>() << std::declval<typename Vector::value_type>(), void()) {
|
||||
using size_type = typename Vector::size_type;
|
||||
|
||||
cl.def("__repr__",
|
||||
[name](Vector &v) {
|
||||
std::ostringstream s;
|
||||
s << name << '[';
|
||||
for (size_type i=0; i < v.size(); ++i) {
|
||||
s << v[i];
|
||||
if (i != v.size() - 1)
|
||||
s << ", ";
|
||||
}
|
||||
s << ']';
|
||||
return s.str();
|
||||
},
|
||||
"Return the canonical string representation of this list."
|
||||
);
|
||||
}
|
||||
|
||||
// Provide the buffer interface for vectors if we have data() and we have a format for it
|
||||
// GCC seems to have "void std::vector<bool>::data()" - doing SFINAE on the existence of data() is insufficient, we need to check it returns an appropriate pointer
|
||||
template <typename Vector, typename = void>
|
||||
struct vector_has_data_and_format : std::false_type {};
|
||||
template <typename Vector>
|
||||
struct vector_has_data_and_format<Vector, enable_if_t<std::is_same<decltype(format_descriptor<typename Vector::value_type>::format(), std::declval<Vector>().data()), typename Vector::value_type*>::value>> : std::true_type {};
|
||||
|
||||
// Add the buffer interface to a vector
|
||||
template <typename Vector, typename Class_, typename... Args>
|
||||
enable_if_t<detail::any_of<std::is_same<Args, buffer_protocol>...>::value>
|
||||
vector_buffer(Class_& cl) {
|
||||
using T = typename Vector::value_type;
|
||||
|
||||
static_assert(vector_has_data_and_format<Vector>::value, "There is not an appropriate format descriptor for this vector");
|
||||
|
||||
// numpy.h declares this for arbitrary types, but it may raise an exception and crash hard at runtime if PYBIND11_NUMPY_DTYPE hasn't been called, so check here
|
||||
format_descriptor<T>::format();
|
||||
|
||||
cl.def_buffer([](Vector& v) -> buffer_info {
|
||||
return buffer_info(v.data(), static_cast<ssize_t>(sizeof(T)), format_descriptor<T>::format(), 1, {v.size()}, {sizeof(T)});
|
||||
});
|
||||
|
||||
cl.def(init([](buffer buf) {
|
||||
auto info = buf.request();
|
||||
if (info.ndim != 1 || info.strides[0] % static_cast<ssize_t>(sizeof(T)))
|
||||
throw type_error("Only valid 1D buffers can be copied to a vector");
|
||||
if (!detail::compare_buffer_info<T>::compare(info) || (ssize_t) sizeof(T) != info.itemsize)
|
||||
throw type_error("Format mismatch (Python: " + info.format + " C++: " + format_descriptor<T>::format() + ")");
|
||||
|
||||
auto vec = std::unique_ptr<Vector>(new Vector());
|
||||
vec->reserve((size_t) info.shape[0]);
|
||||
T *p = static_cast<T*>(info.ptr);
|
||||
ssize_t step = info.strides[0] / static_cast<ssize_t>(sizeof(T));
|
||||
T *end = p + info.shape[0] * step;
|
||||
for (; p != end; p += step)
|
||||
vec->push_back(*p);
|
||||
return vec.release();
|
||||
}));
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
template <typename Vector, typename Class_, typename... Args>
|
||||
enable_if_t<!detail::any_of<std::is_same<Args, buffer_protocol>...>::value> vector_buffer(Class_&) {}
|
||||
|
||||
NAMESPACE_END(detail)
|
||||
|
||||
//
|
||||
// std::vector
|
||||
//
|
||||
template <typename Vector, typename holder_type = std::unique_ptr<Vector>, typename... Args>
|
||||
class_<Vector, holder_type> bind_vector(handle scope, std::string const &name, Args&&... args) {
|
||||
using Class_ = class_<Vector, holder_type>;
|
||||
|
||||
// If the value_type is unregistered (e.g. a converting type) or is itself registered
|
||||
// module-local then make the vector binding module-local as well:
|
||||
using vtype = typename Vector::value_type;
|
||||
auto vtype_info = detail::get_type_info(typeid(vtype));
|
||||
bool local = !vtype_info || vtype_info->module_local;
|
||||
|
||||
Class_ cl(scope, name.c_str(), pybind11::module_local(local), std::forward<Args>(args)...);
|
||||
|
||||
// Declare the buffer interface if a buffer_protocol() is passed in
|
||||
detail::vector_buffer<Vector, Class_, Args...>(cl);
|
||||
|
||||
cl.def(init<>());
|
||||
|
||||
// Register copy constructor (if possible)
|
||||
detail::vector_if_copy_constructible<Vector, Class_>(cl);
|
||||
|
||||
// Register comparison-related operators and functions (if possible)
|
||||
detail::vector_if_equal_operator<Vector, Class_>(cl);
|
||||
|
||||
// Register stream insertion operator (if possible)
|
||||
detail::vector_if_insertion_operator<Vector, Class_>(cl, name);
|
||||
|
||||
// Modifiers require copyable vector value type
|
||||
detail::vector_modifiers<Vector, Class_>(cl);
|
||||
|
||||
// Accessor and iterator; return by value if copyable, otherwise we return by ref + keep-alive
|
||||
detail::vector_accessor<Vector, Class_>(cl);
|
||||
|
||||
cl.def("__bool__",
|
||||
[](const Vector &v) -> bool {
|
||||
return !v.empty();
|
||||
},
|
||||
"Check whether the list is nonempty"
|
||||
);
|
||||
|
||||
cl.def("__len__", &Vector::size);
|
||||
|
||||
|
||||
|
||||
|
||||
#if 0
|
||||
// C++ style functions deprecated, leaving it here as an example
|
||||
cl.def(init<size_type>());
|
||||
|
||||
cl.def("resize",
|
||||
(void (Vector::*) (size_type count)) & Vector::resize,
|
||||
"changes the number of elements stored");
|
||||
|
||||
cl.def("erase",
|
||||
[](Vector &v, SizeType i) {
|
||||
if (i >= v.size())
|
||||
throw index_error();
|
||||
v.erase(v.begin() + i);
|
||||
}, "erases element at index ``i``");
|
||||
|
||||
cl.def("empty", &Vector::empty, "checks whether the container is empty");
|
||||
cl.def("size", &Vector::size, "returns the number of elements");
|
||||
cl.def("push_back", (void (Vector::*)(const T&)) &Vector::push_back, "adds an element to the end");
|
||||
cl.def("pop_back", &Vector::pop_back, "removes the last element");
|
||||
|
||||
cl.def("max_size", &Vector::max_size, "returns the maximum possible number of elements");
|
||||
cl.def("reserve", &Vector::reserve, "reserves storage");
|
||||
cl.def("capacity", &Vector::capacity, "returns the number of elements that can be held in currently allocated storage");
|
||||
cl.def("shrink_to_fit", &Vector::shrink_to_fit, "reduces memory usage by freeing unused memory");
|
||||
|
||||
cl.def("clear", &Vector::clear, "clears the contents");
|
||||
cl.def("swap", &Vector::swap, "swaps the contents");
|
||||
|
||||
cl.def("front", [](Vector &v) {
|
||||
if (v.size()) return v.front();
|
||||
else throw index_error();
|
||||
}, "access the first element");
|
||||
|
||||
cl.def("back", [](Vector &v) {
|
||||
if (v.size()) return v.back();
|
||||
else throw index_error();
|
||||
}, "access the last element ");
|
||||
|
||||
#endif
|
||||
|
||||
return cl;
|
||||
}
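For context, a minimal sketch of how `bind_vector` is typically used from an extension module; the module and class names (`example`, `VectorInt`) are only illustrative, not part of this repository.

```cpp
// Hypothetical usage sketch: expose std::vector<int> to Python as "VectorInt".
#include <pybind11/pybind11.h>
#include <pybind11/stl_bind.h>
#include <vector>

PYBIND11_MAKE_OPAQUE(std::vector<int>);

PYBIND11_MODULE(example, m) {
    // Registers __init__, __len__, __bool__, the modifiers and accessors shown
    // above, and the buffer interface because buffer_protocol() is passed in.
    pybind11::bind_vector<std::vector<int>>(m, "VectorInt", pybind11::buffer_protocol());
}
```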
|
||||
|
||||
|
||||
|
||||
//
|
||||
// std::map, std::unordered_map
|
||||
//
|
||||
|
||||
NAMESPACE_BEGIN(detail)
|
||||
|
||||
/* Fallback functions */
|
||||
template <typename, typename, typename... Args> void map_if_insertion_operator(const Args &...) { }
|
||||
template <typename, typename, typename... Args> void map_assignment(const Args &...) { }
|
||||
|
||||
// Map assignment when copy-assignable: just copy the value
|
||||
template <typename Map, typename Class_>
|
||||
void map_assignment(enable_if_t<std::is_copy_assignable<typename Map::mapped_type>::value, Class_> &cl) {
|
||||
using KeyType = typename Map::key_type;
|
||||
using MappedType = typename Map::mapped_type;
|
||||
|
||||
cl.def("__setitem__",
|
||||
[](Map &m, const KeyType &k, const MappedType &v) {
|
||||
auto it = m.find(k);
|
||||
if (it != m.end()) it->second = v;
|
||||
else m.emplace(k, v);
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
// Not copy-assignable, but still copy-constructible: we can update the value by erasing and reinserting
|
||||
template<typename Map, typename Class_>
|
||||
void map_assignment(enable_if_t<
|
||||
!std::is_copy_assignable<typename Map::mapped_type>::value &&
|
||||
is_copy_constructible<typename Map::mapped_type>::value,
|
||||
Class_> &cl) {
|
||||
using KeyType = typename Map::key_type;
|
||||
using MappedType = typename Map::mapped_type;
|
||||
|
||||
cl.def("__setitem__",
|
||||
[](Map &m, const KeyType &k, const MappedType &v) {
|
||||
// We can't use m[k] = v; because value type might not be default constructible
|
||||
auto r = m.emplace(k, v);
|
||||
if (!r.second) {
|
||||
// value type is not copy assignable so the only way to insert it is to erase it first...
|
||||
m.erase(r.first);
|
||||
m.emplace(k, v);
|
||||
}
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
template <typename Map, typename Class_> auto map_if_insertion_operator(Class_ &cl, std::string const &name)
|
||||
-> decltype(std::declval<std::ostream&>() << std::declval<typename Map::key_type>() << std::declval<typename Map::mapped_type>(), void()) {
|
||||
|
||||
cl.def("__repr__",
|
||||
[name](Map &m) {
|
||||
std::ostringstream s;
|
||||
s << name << '{';
|
||||
bool f = false;
|
||||
for (auto const &kv : m) {
|
||||
if (f)
|
||||
s << ", ";
|
||||
s << kv.first << ": " << kv.second;
|
||||
f = true;
|
||||
}
|
||||
s << '}';
|
||||
return s.str();
|
||||
},
|
||||
"Return the canonical string representation of this map."
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
NAMESPACE_END(detail)
|
||||
|
||||
template <typename Map, typename holder_type = std::unique_ptr<Map>, typename... Args>
|
||||
class_<Map, holder_type> bind_map(handle scope, const std::string &name, Args&&... args) {
|
||||
using KeyType = typename Map::key_type;
|
||||
using MappedType = typename Map::mapped_type;
|
||||
using Class_ = class_<Map, holder_type>;
|
||||
|
||||
// If either type is a non-module-local bound type then make the map binding non-local as well;
|
||||
// otherwise (e.g. both types are either module-local or converting) the map will be
|
||||
// module-local.
|
||||
auto tinfo = detail::get_type_info(typeid(MappedType));
|
||||
bool local = !tinfo || tinfo->module_local;
|
||||
if (local) {
|
||||
tinfo = detail::get_type_info(typeid(KeyType));
|
||||
local = !tinfo || tinfo->module_local;
|
||||
}
|
||||
|
||||
Class_ cl(scope, name.c_str(), pybind11::module_local(local), std::forward<Args>(args)...);
|
||||
|
||||
cl.def(init<>());
|
||||
|
||||
// Register stream insertion operator (if possible)
|
||||
detail::map_if_insertion_operator<Map, Class_>(cl, name);
|
||||
|
||||
cl.def("__bool__",
|
||||
[](const Map &m) -> bool { return !m.empty(); },
|
||||
"Check whether the map is nonempty"
|
||||
);
|
||||
|
||||
cl.def("__iter__",
|
||||
[](Map &m) { return make_key_iterator(m.begin(), m.end()); },
|
||||
keep_alive<0, 1>() /* Essential: keep map alive while iterator exists */
|
||||
);
|
||||
|
||||
cl.def("items",
|
||||
[](Map &m) { return make_iterator(m.begin(), m.end()); },
|
||||
keep_alive<0, 1>() /* Essential: keep map alive while iterator exists */
|
||||
);
|
||||
|
||||
cl.def("__getitem__",
|
||||
[](Map &m, const KeyType &k) -> MappedType & {
|
||||
auto it = m.find(k);
|
||||
if (it == m.end())
|
||||
throw key_error();
|
||||
return it->second;
|
||||
},
|
||||
return_value_policy::reference_internal // ref + keepalive
|
||||
);
|
||||
|
||||
// Assignment provided only if the type is copyable
|
||||
detail::map_assignment<Map, Class_>(cl);
|
||||
|
||||
cl.def("__delitem__",
|
||||
[](Map &m, const KeyType &k) {
|
||||
auto it = m.find(k);
|
||||
if (it == m.end())
|
||||
throw key_error();
|
||||
m.erase(it);
|
||||
}
|
||||
);
|
||||
|
||||
cl.def("__len__", &Map::size);
|
||||
|
||||
return cl;
|
||||
}
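Similarly, a hedged sketch of binding a `std::map` with `bind_map`; again, `example` and `MapStringInt` are illustrative names only.

```cpp
// Hypothetical usage sketch: expose std::map<std::string, int> as "MapStringInt".
#include <pybind11/pybind11.h>
#include <pybind11/stl_bind.h>
#include <map>
#include <string>

PYBIND11_MAKE_OPAQUE(std::map<std::string, int>);

PYBIND11_MODULE(example, m) {
    // __getitem__/__setitem__/__delitem__, key iteration and items() are
    // registered by bind_map exactly as shown in the function above.
    pybind11::bind_map<std::map<std::string, int>>(m, "MapStringInt");
}
```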
|
||||
|
||||
NAMESPACE_END(PYBIND11_NAMESPACE)
|
|
@ -1,53 +0,0 @@
|
|||
/*
|
||||
pybind11/typeid.h: Compiler-independent access to type identifiers
|
||||
|
||||
Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>
|
||||
|
||||
All rights reserved. Use of this source code is governed by a
|
||||
BSD-style license that can be found in the LICENSE file.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstdio>
|
||||
#include <cstdlib>
|
||||
|
||||
#if defined(__GNUG__)
|
||||
#include <cxxabi.h>
|
||||
#endif
|
||||
|
||||
NAMESPACE_BEGIN(pybind11)
|
||||
NAMESPACE_BEGIN(detail)
|
||||
/// Erase all occurrences of a substring
|
||||
inline void erase_all(std::string &string, const std::string &search) {
|
||||
for (size_t pos = 0;;) {
|
||||
pos = string.find(search, pos);
|
||||
if (pos == std::string::npos) break;
|
||||
string.erase(pos, search.length());
|
||||
}
|
||||
}
|
||||
|
||||
PYBIND11_NOINLINE inline void clean_type_id(std::string &name) {
|
||||
#if defined(__GNUG__)
|
||||
int status = 0;
|
||||
std::unique_ptr<char, void (*)(void *)> res {
|
||||
abi::__cxa_demangle(name.c_str(), nullptr, nullptr, &status), std::free };
|
||||
if (status == 0)
|
||||
name = res.get();
|
||||
#else
|
||||
detail::erase_all(name, "class ");
|
||||
detail::erase_all(name, "struct ");
|
||||
detail::erase_all(name, "enum ");
|
||||
#endif
|
||||
detail::erase_all(name, "pybind11::");
|
||||
}
|
||||
NAMESPACE_END(detail)
|
||||
|
||||
/// Return a string representation of a C++ type
|
||||
template <typename T> static std::string type_id() {
|
||||
std::string name(typeid(T).name());
|
||||
detail::clean_type_id(name);
|
||||
return name;
|
||||
}
|
||||
|
||||
NAMESPACE_END(pybind11)
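A small illustrative check of what `type_id<T>()` returns after demangling; the exact output is compiler-dependent, and the example assumes the pybind11 headers (and the Python development headers they require) are on the include path.

```cpp
#include <pybind11/pybind11.h>
#include <iostream>
#include <vector>

int main() {
    // clean_type_id demangles via abi::__cxa_demangle on GCC/Clang, so this
    // typically prints something like "std::vector<int, std::allocator<int> >".
    std::cout << pybind11::type_id<std::vector<int>>() << "\n";
    return 0;
}
```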
|
|
@ -1,195 +0,0 @@
|
|||
// This implementation is from https://github.com/WenmuZhou/PAN.pytorch/blob/master/post_processing/pse.cpp
|
||||
|
||||
#include <queue>
|
||||
#include <math.h>
|
||||
#include <map>
|
||||
#include <algorithm>
|
||||
#include <vector>
|
||||
#include <array>
|
||||
#include "include/pybind11/pybind11.h"
|
||||
#include "include/pybind11/numpy.h"
|
||||
#include "include/pybind11/stl.h"
|
||||
#include "include/pybind11/stl_bind.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
|
||||
namespace panet{
|
||||
py::array_t<uint8_t> assign_pixels(
|
||||
py::array_t<uint8_t, py::array::c_style> text,
|
||||
py::array_t<float, py::array::c_style> similarity_vectors,
|
||||
py::array_t<int32_t, py::array::c_style> label_map,
|
||||
int label_num,
|
||||
float dis_threshold = 0.8)
|
||||
{
|
||||
auto pbuf_text = text.request();
|
||||
auto pbuf_similarity_vectors = similarity_vectors.request();
|
||||
auto pbuf_label_map = label_map.request();
|
||||
if (pbuf_label_map.ndim != 2 || pbuf_label_map.shape[0]==0 || pbuf_label_map.shape[1]==0)
|
||||
throw std::runtime_error("label map must have a shape of (h>0, w>0)");
|
||||
int h = pbuf_label_map.shape[0];
|
||||
int w = pbuf_label_map.shape[1];
|
||||
if (pbuf_similarity_vectors.ndim != 3 || pbuf_similarity_vectors.shape[0]!=h || pbuf_similarity_vectors.shape[1]!=w || pbuf_similarity_vectors.shape[2]!=4 ||
|
||||
pbuf_text.shape[0]!=h || pbuf_text.shape[1]!=w)
|
||||
throw std::runtime_error("similarity_vectors must have a shape of (h,w,4) and text must have a shape of (h,w,4)");
|
||||
// Initialize the result array
|
||||
auto res = py::array_t<uint8_t>(pbuf_text.size);
|
||||
auto pbuf_res = res.request();
|
||||
// Get pointers to text, similarity_vectors and label_map
|
||||
auto ptr_label_map = static_cast<int32_t *>(pbuf_label_map.ptr);
|
||||
auto ptr_text = static_cast<uint8_t *>(pbuf_text.ptr);
|
||||
auto ptr_similarity_vectors = static_cast<float *>(pbuf_similarity_vectors.ptr);
|
||||
auto ptr_res = static_cast<uint8_t *>(pbuf_res.ptr);
|
||||
|
||||
std::queue<std::tuple<int, int, int32_t>> q;
|
||||
// Accumulate the mean similarity vector of each kernel
|
||||
std::vector<std::array<float, 5>> kernel_vector(label_num);
|
||||
|
||||
// Push text pixels into the queue
|
||||
for (int i = 0; i<h; i++)
|
||||
{
|
||||
auto p_label_map = ptr_label_map + i*w;
|
||||
auto p_res = ptr_res + i*w;
|
||||
auto p_similarity_vectors = ptr_similarity_vectors + i*w*4;
|
||||
for(int j = 0, k = 0; j<w && k < w * 4; j++,k+=4)
|
||||
{
|
||||
int32_t label = p_label_map[j];
|
||||
if (label>0)
|
||||
{
|
||||
kernel_vector[label][0] += p_similarity_vectors[k];
|
||||
kernel_vector[label][1] += p_similarity_vectors[k+1];
|
||||
kernel_vector[label][2] += p_similarity_vectors[k+2];
|
||||
kernel_vector[label][3] += p_similarity_vectors[k+3];
|
||||
kernel_vector[label][4] += 1;
|
||||
q.push(std::make_tuple(i, j, label));
|
||||
}
|
||||
p_res[j] = label;
|
||||
}
|
||||
}
|
||||
|
||||
for(int i=0;i<label_num;i++)
|
||||
{
|
||||
for (int j=0;j<4;j++)
|
||||
{
|
||||
kernel_vector[i][j] /= kernel_vector[i][4];
|
||||
}
|
||||
}
|
||||
|
||||
int dx[4] = {-1, 1, 0, 0};
|
||||
int dy[4] = {0, 0, -1, 1};
|
||||
while(!q.empty()){
|
||||
// get each queue member in q
|
||||
auto q_n = q.front();
|
||||
q.pop();
|
||||
int y = std::get<0>(q_n);
|
||||
int x = std::get<1>(q_n);
|
||||
int32_t l = std::get<2>(q_n);
|
||||
//store the edge pixel after one expansion
|
||||
auto kernel_cv = kernel_vector[l];
|
||||
for (int idx=0; idx<4; idx++)
|
||||
{
|
||||
int tmpy = y + dy[idx];
|
||||
int tmpx = x + dx[idx];
|
||||
auto p_res = ptr_res + tmpy*w;
|
||||
if (tmpy<0 || tmpy>=h || tmpx<0 || tmpx>=w)
|
||||
continue;
|
||||
if (!ptr_text[tmpy*w+tmpx] || p_res[tmpx]>0)
|
||||
continue;
|
||||
// Compute the distance to the kernel's mean similarity vector
|
||||
float dis = 0;
|
||||
auto p_similarity_vectors = ptr_similarity_vectors + tmpy * w*4;
|
||||
for(size_t i=0;i<4;i++)
|
||||
{
|
||||
dis += pow(kernel_cv[i] - p_similarity_vectors[tmpx*4 + i],2);
|
||||
}
|
||||
dis = sqrt(dis);
|
||||
if(dis >= dis_threshold)
|
||||
continue;
|
||||
q.push(std::make_tuple(tmpy, tmpx, l));
|
||||
p_res[tmpx]=l;
|
||||
}
|
||||
}
|
||||
return res;
|
||||
}
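The expansion step above admits a neighbouring pixel into a kernel when the Euclidean distance between its 4-d similarity vector and the kernel's mean vector is below `dis_threshold`. A small self-contained sketch of that check (the helper name is illustrative):

```cpp
#include <array>
#include <cmath>

// Returns true if the pixel's similarity vector is close enough to the
// kernel's mean similarity vector to be absorbed during expansion.
inline bool close_enough(const std::array<float, 4> &kernel_mean,
                         const float *pixel_vec, float dis_threshold) {
    float dis = 0.f;
    for (size_t i = 0; i < 4; ++i) {
        float d = kernel_mean[i] - pixel_vec[i];
        dis += d * d;
    }
    return std::sqrt(dis) < dis_threshold;
}
```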
|
||||
|
||||
std::map<int,std::vector<float>> estimate_text_confidence(
|
||||
py::array_t<int32_t, py::array::c_style> label_map,
|
||||
py::array_t<float, py::array::c_style> score_map,
|
||||
int label_num)
|
||||
{
|
||||
auto pbuf_label_map = label_map.request();
|
||||
auto pbuf_score_map = score_map.request();
|
||||
auto ptr_label_map = static_cast<int32_t *>(pbuf_label_map.ptr);
|
||||
auto ptr_score_map = static_cast<float *>(pbuf_score_map.ptr);
|
||||
int h = pbuf_label_map.shape[0];
|
||||
int w = pbuf_label_map.shape[1];
|
||||
|
||||
std::map<int,std::vector<float>> point_dict;
|
||||
std::vector<std::vector<float>> point_vector;
|
||||
for(int i=0;i<label_num;i++)
|
||||
{
|
||||
std::vector<float> point;
|
||||
point.push_back(0);
|
||||
point.push_back(0);
|
||||
point_vector.push_back(point);
|
||||
}
|
||||
for (int i = 0; i<h; i++)
|
||||
{
|
||||
auto p_label_map = ptr_label_map + i*w;
|
||||
auto p_score_map = ptr_score_map + i*w;
|
||||
for(int j = 0; j<w; j++)
|
||||
{
|
||||
int32_t label = p_label_map[j];
|
||||
if(label==0)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
float score = p_score_map[j];
|
||||
point_vector[label][0] += score;
|
||||
point_vector[label][1] += 1;
|
||||
point_vector[label].push_back(j);
|
||||
point_vector[label].push_back(i);
|
||||
}
|
||||
}
|
||||
for(int i=0;i<label_num;i++)
|
||||
{
|
||||
if(point_vector[i].size() > 2)
|
||||
{
|
||||
point_vector[i][0] /= point_vector[i][1];
|
||||
point_dict[i] = point_vector[i];
|
||||
}
|
||||
}
|
||||
return point_dict;
|
||||
}
|
||||
std::vector<int> get_pixel_num(
|
||||
py::array_t<int32_t, py::array::c_style> label_map,
|
||||
int label_num)
|
||||
{
|
||||
auto pbuf_label_map = label_map.request();
|
||||
auto ptr_label_map = static_cast<int32_t *>(pbuf_label_map.ptr);
|
||||
int h = pbuf_label_map.shape[0];
|
||||
int w = pbuf_label_map.shape[1];
|
||||
|
||||
std::vector<int> point_vector;
|
||||
for(int i=0;i<label_num;i++){
|
||||
point_vector.push_back(0);
|
||||
}
|
||||
for (int i = 0; i<h; i++){
|
||||
auto p_label_map = ptr_label_map + i*w;
|
||||
for(int j = 0; j<w; j++){
|
||||
int32_t label = p_label_map[j];
|
||||
if(label==0){
|
||||
continue;
|
||||
}
|
||||
point_vector[label] += 1;
|
||||
}
|
||||
}
|
||||
return point_vector;
|
||||
}
|
||||
}
|
||||
|
||||
PYBIND11_MODULE(pan, m){
|
||||
m.def("assign_pixels", &panet::assign_pixels, " assign pixels to text kernels", py::arg("text"), py::arg("similarity_vectors"), py::arg("label_map"), py::arg("label_num"), py::arg("dis_threshold")=0.8);
|
||||
m.def("estimate_text_confidence", &panet::estimate_text_confidence, " estimate average confidence for text instances", py::arg("label_map"), py::arg("score_map"), py::arg("label_num"));
|
||||
m.def("get_pixel_num", &panet::get_pixel_num, " get each text instance pixel number", py::arg("label_map"), py::arg("label_num"));
|
||||
}
|
|
@ -1,128 +0,0 @@
|
|||
// This implementation is from https://github.com/whai362/PSENet/blob/master/pse/adaptor.cpp
|
||||
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "pybind11/numpy.h"
|
||||
#include "pybind11/stl.h"
|
||||
#include "pybind11/stl_bind.h"
|
||||
|
||||
#include <iostream>
|
||||
#include <queue>
|
||||
|
||||
using namespace std;
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
namespace pse_adaptor {
|
||||
|
||||
class Point2d {
|
||||
public:
|
||||
int x;
|
||||
int y;
|
||||
|
||||
Point2d() : x(0), y(0)
|
||||
{}
|
||||
|
||||
Point2d(int _x, int _y) : x(_x), y(_y)
|
||||
{}
|
||||
};
|
||||
|
||||
void growing_text_line(const int *data,
|
||||
vector<pybind11::ssize_t> &data_shape,
|
||||
const int *label_map,
|
||||
vector<pybind11::ssize_t> &label_shape,
|
||||
int &label_num,
|
||||
float &min_area,
|
||||
vector<vector<int>> &text_line) {
|
||||
std::vector<int> area(label_num + 1);
|
||||
for (int x = 0; x < label_shape[0]; ++x) {
|
||||
for (int y = 0; y < label_shape[1]; ++y) {
|
||||
int label = label_map[x * label_shape[1] + y];
|
||||
if (label == 0) continue;
|
||||
area[label] += 1;
|
||||
}
|
||||
}
|
||||
|
||||
queue<Point2d> queue, next_queue;
|
||||
for (int x = 0; x < label_shape[0]; ++x) {
|
||||
vector<int> row(label_shape[1]);
|
||||
for (int y = 0; y < label_shape[1]; ++y) {
|
||||
int label = label_map[x * label_shape[1] + y];
|
||||
if (label == 0) continue;
|
||||
if (area[label] < min_area) continue;
|
||||
|
||||
Point2d point(x, y);
|
||||
queue.push(point);
|
||||
row[y] = label;
|
||||
}
|
||||
text_line.emplace_back(row);
|
||||
}
|
||||
|
||||
int dx[] = {-1, 1, 0, 0};
|
||||
int dy[] = {0, 0, -1, 1};
|
||||
|
||||
for (int kernel_id = data_shape[0] - 2; kernel_id >= 0; --kernel_id) {
|
||||
while (!queue.empty()) {
|
||||
Point2d point = queue.front();
|
||||
queue.pop();
|
||||
int x = point.x;
|
||||
int y = point.y;
|
||||
int label = text_line[x][y];
|
||||
|
||||
bool is_edge = true;
|
||||
for (int d = 0; d < 4; ++d) {
|
||||
int tmp_x = x + dx[d];
|
||||
int tmp_y = y + dy[d];
|
||||
|
||||
if (tmp_x < 0 || tmp_x >= (int)text_line.size()) continue;
|
||||
if (tmp_y < 0 || tmp_y >= (int)text_line[1].size()) continue;
|
||||
int kernel_value = data[kernel_id * data_shape[1] * data_shape[2] + tmp_x * data_shape[2] + tmp_y];
|
||||
if (kernel_value == 0) continue;
|
||||
if (text_line[tmp_x][tmp_y] > 0) continue;
|
||||
|
||||
Point2d point(tmp_x, tmp_y);
|
||||
queue.push(point);
|
||||
text_line[tmp_x][tmp_y] = label;
|
||||
is_edge = false;
|
||||
}
|
||||
|
||||
if (is_edge) {
|
||||
next_queue.push(point);
|
||||
}
|
||||
}
|
||||
swap(queue, next_queue);
|
||||
}
|
||||
}
|
||||
|
||||
vector<vector<int>> pse(py::array_t<int, py::array::c_style | py::array::forcecast> quad_n9,
|
||||
float min_area,
|
||||
py::array_t<int32_t, py::array::c_style> label_map,
|
||||
int label_num) {
|
||||
auto buf = quad_n9.request();
|
||||
auto data = static_cast<int *>(buf.ptr);
|
||||
vector<pybind11::ssize_t> data_shape = buf.shape;
|
||||
|
||||
auto buf_label_map = label_map.request();
|
||||
auto data_label_map = static_cast<int32_t *>(buf_label_map.ptr);
|
||||
vector<pybind11::ssize_t> label_map_shape = buf_label_map.shape;
|
||||
|
||||
vector<vector<int>> text_line;
|
||||
|
||||
growing_text_line(data,
|
||||
data_shape,
|
||||
data_label_map,
|
||||
label_map_shape,
|
||||
label_num,
|
||||
min_area,
|
||||
text_line);
|
||||
|
||||
return text_line;
|
||||
}
|
||||
}
|
||||
|
||||
PYBIND11_PLUGIN(pse) {
|
||||
py::module m("pse", "pse");
|
||||
|
||||
m.def("pse", &pse_adaptor::pse, "pse");
|
||||
|
||||
return m.ptr();
|
||||
}
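Note that `PYBIND11_PLUGIN` is the older module-initialization macro; newer pybind11 releases prefer `PYBIND11_MODULE`. A hedged sketch of the equivalent registration for this module:

```cpp
// Equivalent registration with the newer macro (sketch only, not part of the file above).
PYBIND11_MODULE(pse, m) {
    m.doc() = "pse";
    m.def("pse", &pse_adaptor::pse, "pse");
}
```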
|
|
@ -2,6 +2,7 @@ import cv2
|
|||
import numpy as np
|
||||
import pyclipper
|
||||
import torch
|
||||
from mmcv.ops import contour_expand, pixel_group
|
||||
from numpy.linalg import norm
|
||||
from shapely.geometry import Polygon
|
||||
from skimage.morphology import skeletonize
|
||||
|
@ -36,7 +37,6 @@ def pan_decode(preds,
|
|||
min_text_confidence=0.5,
|
||||
min_kernel_confidence=0.5,
|
||||
min_text_avg_confidence=0.85,
|
||||
min_kernel_area=0,
|
||||
min_text_area=16):
|
||||
"""Convert scores to quadrangles via post processing in PANet. This is
|
||||
partially adapted from https://github.com/WenmuZhou/PAN.pytorch.
|
||||
|
@ -47,13 +47,11 @@ def pan_decode(preds,
|
|||
min_text_confidence (float): The minimal text confidence.
|
||||
min_kernel_confidence (float): The minimal kernel confidence.
|
||||
min_text_avg_confidence (float): The minimal text average confidence.
|
||||
min_kernel_area (int): The minimal text kernel area.
|
||||
min_text_area (int): The minimal text instance region area.
|
||||
Returns:
|
||||
boundaries (list[list[float]]): The instance boundary and its
|
||||
instance confidence list.
|
||||
"""
|
||||
from .pan import assign_pixels, estimate_text_confidence, get_pixel_num
|
||||
preds[:2, :, :] = torch.sigmoid(preds[:2, :, :])
|
||||
preds = preds.detach().cpu().numpy()
|
||||
|
||||
|
@ -64,28 +62,16 @@ def pan_decode(preds,
|
|||
|
||||
region_num, labels = cv2.connectedComponents(
|
||||
kernel.astype(np.uint8), connectivity=4)
|
||||
valid_kernel_inx = []
|
||||
region_pixel_num = get_pixel_num(labels, region_num)
|
||||
|
||||
# start from index 1; label 0 is the background and thus meaningless.
|
||||
for region_idx in range(1, region_num):
|
||||
if region_pixel_num[region_idx] < min_kernel_area:
|
||||
continue
|
||||
valid_kernel_inx.append(region_idx)
|
||||
|
||||
# assign pixels to valid kernels
|
||||
assignment = assign_pixels(
|
||||
text.astype(np.uint8), embeddings, labels, region_num, 0.8)
|
||||
assignment = assignment.reshape(text.shape)
|
||||
contours, _ = cv2.findContours((kernel * 255).astype(np.uint8),
|
||||
cv2.RETR_LIST, cv2.CHAIN_APPROX_NONE)
|
||||
kernel_contours = np.zeros(text.shape, dtype='uint8')
|
||||
cv2.drawContours(kernel_contours, contours, -1, 255)
|
||||
text_points = pixel_group(text_score, text, embeddings, labels,
|
||||
kernel_contours, region_num,
|
||||
min_text_avg_confidence)
|
||||
|
||||
boundaries = []
|
||||
|
||||
# compute text avg confidence
|
||||
|
||||
text_points = estimate_text_confidence(assignment, text_score, region_num)
|
||||
for text_inx, text_point in text_points.items():
|
||||
if text_inx not in valid_kernel_inx:
|
||||
continue
|
||||
for text_inx, text_point in enumerate(text_points):
|
||||
text_confidence = text_point[0]
|
||||
text_point = text_point[2:]
|
||||
text_point = np.array(text_point, dtype=int).reshape(-1, 2)
|
||||
|
@ -133,17 +119,16 @@ def pse_decode(preds,
|
|||
score = score.data.cpu().numpy().astype(np.float32) # to numpy
|
||||
|
||||
kernel_masks = kernel_masks.data.cpu().numpy().astype(np.uint8) # to numpy
|
||||
from .pse import pse
|
||||
|
||||
region_num, labels = cv2.connectedComponents(
|
||||
kernel_masks[-1], connectivity=4)
|
||||
|
||||
# labels = pse(kernel_masks, min_kernel_area)
|
||||
labels = pse(kernel_masks, min_kernel_area, labels, region_num)
|
||||
labels = contour_expand(kernel_masks, labels, min_kernel_area, region_num)
|
||||
labels = np.array(labels)
|
||||
label_num = np.max(labels) + 1
|
||||
label_num = np.max(labels)
|
||||
boundaries = []
|
||||
for i in range(1, label_num):
|
||||
for i in range(1, label_num + 1):
|
||||
points = np.array(np.where(labels == i)).transpose((1, 0))[:, ::-1]
|
||||
area = points.shape[0]
|
||||
score_instance = np.mean(score[labels == i])
|
||||
|
|
|
@ -1,3 +0,0 @@
|
|||
from .ops.rroi_align import RROIAlign
|
||||
|
||||
__all__ = ['RROIAlign']
|
|
@ -1,3 +0,0 @@
|
|||
from .rroi_align import RROIAlign
|
||||
|
||||
__all__ = ['RROIAlign']
|
|
@ -1,46 +0,0 @@
|
|||
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
|
||||
#pragma once
|
||||
|
||||
#include "cpu/vision.h"
|
||||
|
||||
#ifdef WITH_CUDA
|
||||
#include "cuda/vision.h"
|
||||
#endif
|
||||
|
||||
// Interface for Python
|
||||
std::tuple<at::Tensor, at::Tensor, at::Tensor> RROIAlign_forward(const at::Tensor& input,
|
||||
const at::Tensor& rois,
|
||||
const float spatial_scale,
|
||||
const int pooled_height,
|
||||
const int pooled_width) {
|
||||
if (input.type().is_cuda()) {
|
||||
#ifdef WITH_CUDA
|
||||
return RROIAlign_forward_cuda(input, rois, spatial_scale, pooled_height, pooled_width);
|
||||
#else
|
||||
AT_ERROR("Not compiled with GPU support");
|
||||
#endif
|
||||
}
|
||||
AT_ERROR("Not implemented on the CPU");
|
||||
//return RROIAlign_forward_cpu(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio);
|
||||
}
|
||||
|
||||
at::Tensor RROIAlign_backward(const at::Tensor& grad,
|
||||
const at::Tensor& rois,
|
||||
const at::Tensor& con_idx_x,
|
||||
const at::Tensor& con_idx_y,
|
||||
const float spatial_scale,
|
||||
const int pooled_height,
|
||||
const int pooled_width,
|
||||
const int batch_size,
|
||||
const int channels,
|
||||
const int height,
|
||||
const int width) {
|
||||
if (grad.type().is_cuda()) {
|
||||
#ifdef WITH_CUDA
|
||||
return RROIAlign_backward_cuda(grad, rois, con_idx_x, con_idx_y, spatial_scale, pooled_height, pooled_width, batch_size, channels, height, width);
|
||||
#else
|
||||
AT_ERROR("Not compiled with GPU support");
|
||||
#endif
|
||||
}
|
||||
AT_ERROR("Not implemented on the CPU");
|
||||
}
|
|
@ -1,257 +0,0 @@
|
|||
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
|
||||
#include "vision.h"
|
||||
|
||||
// implementation taken from Caffe2
|
||||
template <typename T>
|
||||
struct PreCalc {
|
||||
int pos1;
|
||||
int pos2;
|
||||
int pos3;
|
||||
int pos4;
|
||||
T w1;
|
||||
T w2;
|
||||
T w3;
|
||||
T w4;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
void pre_calc_for_bilinear_interpolate(
|
||||
const int height,
|
||||
const int width,
|
||||
const int pooled_height,
|
||||
const int pooled_width,
|
||||
const int iy_upper,
|
||||
const int ix_upper,
|
||||
T roi_start_h,
|
||||
T roi_start_w,
|
||||
T bin_size_h,
|
||||
T bin_size_w,
|
||||
int roi_bin_grid_h,
|
||||
int roi_bin_grid_w,
|
||||
std::vector<PreCalc<T>>& pre_calc) {
|
||||
int pre_calc_index = 0;
|
||||
for (int ph = 0; ph < pooled_height; ph++) {
|
||||
for (int pw = 0; pw < pooled_width; pw++) {
|
||||
for (int iy = 0; iy < iy_upper; iy++) {
|
||||
const T yy = roi_start_h + ph * bin_size_h +
|
||||
static_cast<T>(iy + .5f) * bin_size_h /
|
||||
static_cast<T>(roi_bin_grid_h); // e.g., 0.5, 1.5
|
||||
for (int ix = 0; ix < ix_upper; ix++) {
|
||||
const T xx = roi_start_w + pw * bin_size_w +
|
||||
static_cast<T>(ix + .5f) * bin_size_w /
|
||||
static_cast<T>(roi_bin_grid_w);
|
||||
|
||||
T x = xx;
|
||||
T y = yy;
|
||||
// deal with cases where inverse elements are out of the feature map boundary
|
||||
if (y < -1.0 || y > height || x < -1.0 || x > width) {
|
||||
// empty
|
||||
PreCalc<T> pc;
|
||||
pc.pos1 = 0;
|
||||
pc.pos2 = 0;
|
||||
pc.pos3 = 0;
|
||||
pc.pos4 = 0;
|
||||
pc.w1 = 0;
|
||||
pc.w2 = 0;
|
||||
pc.w3 = 0;
|
||||
pc.w4 = 0;
|
||||
pre_calc[pre_calc_index] = pc;
|
||||
pre_calc_index += 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (y <= 0) {
|
||||
y = 0;
|
||||
}
|
||||
if (x <= 0) {
|
||||
x = 0;
|
||||
}
|
||||
|
||||
int y_low = (int)y;
|
||||
int x_low = (int)x;
|
||||
int y_high;
|
||||
int x_high;
|
||||
|
||||
if (y_low >= height - 1) {
|
||||
y_high = y_low = height - 1;
|
||||
y = (T)y_low;
|
||||
} else {
|
||||
y_high = y_low + 1;
|
||||
}
|
||||
|
||||
if (x_low >= width - 1) {
|
||||
x_high = x_low = width - 1;
|
||||
x = (T)x_low;
|
||||
} else {
|
||||
x_high = x_low + 1;
|
||||
}
|
||||
|
||||
T ly = y - y_low;
|
||||
T lx = x - x_low;
|
||||
T hy = 1. - ly, hx = 1. - lx;
|
||||
T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
|
||||
|
||||
// save weights and indices
|
||||
PreCalc<T> pc;
|
||||
pc.pos1 = y_low * width + x_low;
|
||||
pc.pos2 = y_low * width + x_high;
|
||||
pc.pos3 = y_high * width + x_low;
|
||||
pc.pos4 = y_high * width + x_high;
|
||||
pc.w1 = w1;
|
||||
pc.w2 = w2;
|
||||
pc.w3 = w3;
|
||||
pc.w4 = w4;
|
||||
pre_calc[pre_calc_index] = pc;
|
||||
|
||||
pre_calc_index += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void ROIAlignForward_cpu_kernel(
|
||||
const int nthreads,
|
||||
const T* bottom_data,
|
||||
const T& spatial_scale,
|
||||
const int channels,
|
||||
const int height,
|
||||
const int width,
|
||||
const int pooled_height,
|
||||
const int pooled_width,
|
||||
const int sampling_ratio,
|
||||
const T* bottom_rois,
|
||||
//int roi_cols,
|
||||
T* top_data) {
|
||||
//AT_ASSERT(roi_cols == 4 || roi_cols == 5);
|
||||
int roi_cols = 5;
|
||||
|
||||
int n_rois = nthreads / channels / pooled_width / pooled_height;
|
||||
// (n, c, ph, pw) is an element in the pooled output
|
||||
// can be parallelized using omp
|
||||
// #pragma omp parallel for num_threads(32)
|
||||
for (int n = 0; n < n_rois; n++) {
|
||||
int index_n = n * channels * pooled_width * pooled_height;
|
||||
|
||||
// roi could have 4 or 5 columns
|
||||
const T* offset_bottom_rois = bottom_rois + n * roi_cols;
|
||||
int roi_batch_ind = 0;
|
||||
if (roi_cols == 5) {
|
||||
roi_batch_ind = offset_bottom_rois[0];
|
||||
offset_bottom_rois++;
|
||||
}
|
||||
|
||||
// Do not use rounding; this implementation detail is critical
|
||||
T roi_start_w = offset_bottom_rois[0] * spatial_scale;
|
||||
T roi_start_h = offset_bottom_rois[1] * spatial_scale;
|
||||
T roi_end_w = offset_bottom_rois[2] * spatial_scale;
|
||||
T roi_end_h = offset_bottom_rois[3] * spatial_scale;
|
||||
// T roi_start_w = round(offset_bottom_rois[0] * spatial_scale);
|
||||
// T roi_start_h = round(offset_bottom_rois[1] * spatial_scale);
|
||||
// T roi_end_w = round(offset_bottom_rois[2] * spatial_scale);
|
||||
// T roi_end_h = round(offset_bottom_rois[3] * spatial_scale);
|
||||
|
||||
// Force malformed ROIs to be 1x1
|
||||
T roi_width = std::max(roi_end_w - roi_start_w, (T)1.);
|
||||
T roi_height = std::max(roi_end_h - roi_start_h, (T)1.);
|
||||
T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
|
||||
T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
|
||||
|
||||
// We use roi_bin_grid to sample the grid and mimic integral
|
||||
int roi_bin_grid_h = (sampling_ratio > 0)
|
||||
? sampling_ratio
|
||||
: ceil(roi_height / pooled_height); // e.g., = 2
|
||||
int roi_bin_grid_w =
|
||||
(sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);
|
||||
|
||||
// We do average (integral) pooling inside a bin
|
||||
const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4
|
||||
|
||||
// we want to precalculate indices and weights shared by all channels,
|
||||
// this is the key point of optimization
|
||||
std::vector<PreCalc<T>> pre_calc(
|
||||
roi_bin_grid_h * roi_bin_grid_w * pooled_width * pooled_height);
|
||||
pre_calc_for_bilinear_interpolate(
|
||||
height,
|
||||
width,
|
||||
pooled_height,
|
||||
pooled_width,
|
||||
roi_bin_grid_h,
|
||||
roi_bin_grid_w,
|
||||
roi_start_h,
|
||||
roi_start_w,
|
||||
bin_size_h,
|
||||
bin_size_w,
|
||||
roi_bin_grid_h,
|
||||
roi_bin_grid_w,
|
||||
pre_calc);
|
||||
|
||||
for (int c = 0; c < channels; c++) {
|
||||
int index_n_c = index_n + c * pooled_width * pooled_height;
|
||||
const T* offset_bottom_data =
|
||||
bottom_data + (roi_batch_ind * channels + c) * height * width;
|
||||
int pre_calc_index = 0;
|
||||
|
||||
for (int ph = 0; ph < pooled_height; ph++) {
|
||||
for (int pw = 0; pw < pooled_width; pw++) {
|
||||
int index = index_n_c + ph * pooled_width + pw;
|
||||
|
||||
T output_val = 0.;
|
||||
for (int iy = 0; iy < roi_bin_grid_h; iy++) {
|
||||
for (int ix = 0; ix < roi_bin_grid_w; ix++) {
|
||||
PreCalc<T> pc = pre_calc[pre_calc_index];
|
||||
output_val += pc.w1 * offset_bottom_data[pc.pos1] +
|
||||
pc.w2 * offset_bottom_data[pc.pos2] +
|
||||
pc.w3 * offset_bottom_data[pc.pos3] +
|
||||
pc.w4 * offset_bottom_data[pc.pos4];
|
||||
|
||||
pre_calc_index += 1;
|
||||
}
|
||||
}
|
||||
output_val /= count;
|
||||
|
||||
top_data[index] = output_val;
|
||||
} // for pw
|
||||
} // for ph
|
||||
} // for c
|
||||
} // for n
|
||||
}
|
||||
|
||||
at::Tensor ROIAlign_forward_cpu(const at::Tensor& input,
|
||||
const at::Tensor& rois,
|
||||
const float spatial_scale,
|
||||
const int pooled_height,
|
||||
const int pooled_width,
|
||||
const int sampling_ratio) {
|
||||
AT_ASSERTM(!input.type().is_cuda(), "input must be a CPU tensor");
|
||||
AT_ASSERTM(!rois.type().is_cuda(), "rois must be a CPU tensor");
|
||||
|
||||
auto num_rois = rois.size(0);
|
||||
auto channels = input.size(1);
|
||||
auto height = input.size(2);
|
||||
auto width = input.size(3);
|
||||
|
||||
auto output = at::empty({num_rois, channels, pooled_height, pooled_width}, input.options());
|
||||
auto output_size = num_rois * pooled_height * pooled_width * channels;
|
||||
|
||||
if (output.numel() == 0) {
|
||||
return output;
|
||||
}
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES(input.type(), "ROIAlign_forward", [&] {
|
||||
ROIAlignForward_cpu_kernel<scalar_t>(
|
||||
output_size,
|
||||
input.data<scalar_t>(),
|
||||
spatial_scale,
|
||||
channels,
|
||||
height,
|
||||
width,
|
||||
pooled_height,
|
||||
pooled_width,
|
||||
sampling_ratio,
|
||||
rois.data<scalar_t>(),
|
||||
output.data<scalar_t>());
|
||||
});
|
||||
return output;
|
||||
}
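The per-sample value in the kernels above is a standard bilinear interpolation of the four neighbouring feature-map cells. A minimal standalone sketch of that computation, mirroring the clamping and weighting used by the ROIAlign code (illustrative, CPU-only):

```cpp
#include <algorithm>

// Bilinearly interpolate a single (y, x) sample from a height x width map.
template <typename T>
T bilinear_sample(const T *data, int height, int width, T y, T x) {
    // Out-of-boundary samples contribute nothing, as in the kernels above.
    if (y < -1.0 || y > height || x < -1.0 || x > width) return 0;
    if (y <= 0) y = 0;
    if (x <= 0) x = 0;

    int y_low = (int)y, x_low = (int)x;
    int y_high, x_high;
    if (y_low >= height - 1) { y_high = y_low = height - 1; y = (T)y_low; }
    else { y_high = y_low + 1; }
    if (x_low >= width - 1) { x_high = x_low = width - 1; x = (T)x_low; }
    else { x_high = x_low + 1; }

    T ly = y - y_low, lx = x - x_low;
    T hy = 1. - ly, hx = 1. - lx;
    return hy * hx * data[y_low * width + x_low] + hy * lx * data[y_low * width + x_high] +
           ly * hx * data[y_high * width + x_low] + ly * lx * data[y_high * width + x_high];
}
```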
|
|
@ -1,11 +0,0 @@
|
|||
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
|
||||
#pragma once
|
||||
#include <torch/extension.h>
|
||||
|
||||
|
||||
at::Tensor ROIAlign_forward_cpu(const at::Tensor& input,
|
||||
const at::Tensor& rois,
|
||||
const float spatial_scale,
|
||||
const int pooled_height,
|
||||
const int pooled_width,
|
||||
const int sampling_ratio);
|
|
@ -1,346 +0,0 @@
|
|||
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
|
||||
#include <THC/THC.h>
|
||||
#include <THC/THCAtomics.cuh>
|
||||
#include <THC/THCDeviceUtils.cuh>
|
||||
|
||||
// TODO make it in a common file
|
||||
#define CUDA_1D_KERNEL_LOOP(i, n) \
|
||||
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \
|
||||
i += blockDim.x * gridDim.x)
|
||||
|
||||
|
||||
template <typename T>
|
||||
__device__ T bilinear_interpolate(const T* bottom_data,
|
||||
const int height, const int width,
|
||||
T y, T x,
|
||||
const int index /* index for debug only*/) {
|
||||
|
||||
// deal with cases where inverse elements are out of the feature map boundary
|
||||
if (y < -1.0 || y > height || x < -1.0 || x > width) {
|
||||
//empty
|
||||
return 0;
|
||||
}
|
||||
|
||||
if (y <= 0) y = 0;
|
||||
if (x <= 0) x = 0;
|
||||
|
||||
int y_low = (int) y;
|
||||
int x_low = (int) x;
|
||||
int y_high;
|
||||
int x_high;
|
||||
|
||||
if (y_low >= height - 1) {
|
||||
y_high = y_low = height - 1;
|
||||
y = (T) y_low;
|
||||
} else {
|
||||
y_high = y_low + 1;
|
||||
}
|
||||
|
||||
if (x_low >= width - 1) {
|
||||
x_high = x_low = width - 1;
|
||||
x = (T) x_low;
|
||||
} else {
|
||||
x_high = x_low + 1;
|
||||
}
|
||||
|
||||
T ly = y - y_low;
|
||||
T lx = x - x_low;
|
||||
T hy = 1. - ly, hx = 1. - lx;
|
||||
// do bilinear interpolation
|
||||
T v1 = bottom_data[y_low * width + x_low];
|
||||
T v2 = bottom_data[y_low * width + x_high];
|
||||
T v3 = bottom_data[y_high * width + x_low];
|
||||
T v4 = bottom_data[y_high * width + x_high];
|
||||
T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
|
||||
|
||||
T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
|
||||
|
||||
return val;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void RoIAlignForward(const int nthreads, const T* bottom_data,
|
||||
const T spatial_scale, const int channels,
|
||||
const int height, const int width,
|
||||
const int pooled_height, const int pooled_width,
|
||||
const int sampling_ratio,
|
||||
const T* bottom_rois, T* top_data) {
|
||||
CUDA_1D_KERNEL_LOOP(index, nthreads) {
|
||||
// (n, c, ph, pw) is an element in the pooled output
|
||||
int pw = index % pooled_width;
|
||||
int ph = (index / pooled_width) % pooled_height;
|
||||
int c = (index / pooled_width / pooled_height) % channels;
|
||||
int n = index / pooled_width / pooled_height / channels;
|
||||
|
||||
const T* offset_bottom_rois = bottom_rois + n * 5;
|
||||
int roi_batch_ind = offset_bottom_rois[0];
|
||||
|
||||
// Do not use rounding; this implementation detail is critical
|
||||
T roi_start_w = offset_bottom_rois[1] * spatial_scale;
|
||||
T roi_start_h = offset_bottom_rois[2] * spatial_scale;
|
||||
T roi_end_w = offset_bottom_rois[3] * spatial_scale;
|
||||
T roi_end_h = offset_bottom_rois[4] * spatial_scale;
|
||||
// T roi_start_w = round(offset_bottom_rois[1] * spatial_scale);
|
||||
// T roi_start_h = round(offset_bottom_rois[2] * spatial_scale);
|
||||
// T roi_end_w = round(offset_bottom_rois[3] * spatial_scale);
|
||||
// T roi_end_h = round(offset_bottom_rois[4] * spatial_scale);
|
||||
|
||||
// Force malformed ROIs to be 1x1
|
||||
T roi_width = max(roi_end_w - roi_start_w, (T)1.);
|
||||
T roi_height = max(roi_end_h - roi_start_h, (T)1.);
|
||||
T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
|
||||
T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
|
||||
|
||||
const T* offset_bottom_data = bottom_data + (roi_batch_ind * channels + c) * height * width;
|
||||
|
||||
// We use roi_bin_grid to sample the grid and mimic integral
|
||||
int roi_bin_grid_h = (sampling_ratio > 0) ? sampling_ratio : ceilf(roi_height / pooled_height); // e.g., = 2
|
||||
int roi_bin_grid_w = (sampling_ratio > 0) ? sampling_ratio : ceilf(roi_width / pooled_width);
|
||||
|
||||
// We do average (integral) pooling inside a bin
|
||||
const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4
|
||||
|
||||
T output_val = 0.;
|
||||
for (int iy = 0; iy < roi_bin_grid_h; iy ++) // e.g., iy = 0, 1
|
||||
{
|
||||
const T y = roi_start_h + ph * bin_size_h + static_cast<T>(iy + .5f) * bin_size_h / static_cast<T>(roi_bin_grid_h); // e.g., 0.5, 1.5
|
||||
for (int ix = 0; ix < roi_bin_grid_w; ix ++)
|
||||
{
|
||||
const T x = roi_start_w + pw * bin_size_w + static_cast<T>(ix + .5f) * bin_size_w / static_cast<T>(roi_bin_grid_w);
|
||||
|
||||
T val = bilinear_interpolate(offset_bottom_data, height, width, y, x, index);
|
||||
output_val += val;
|
||||
}
|
||||
}
|
||||
output_val /= count;
|
||||
|
||||
top_data[index] = output_val;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template <typename T>
|
||||
__device__ void bilinear_interpolate_gradient(
|
||||
const int height, const int width,
|
||||
T y, T x,
|
||||
T & w1, T & w2, T & w3, T & w4,
|
||||
int & x_low, int & x_high, int & y_low, int & y_high,
|
||||
const int index /* index for debug only*/) {
|
||||
|
||||
// deal with cases where inverse elements are out of the feature map boundary
|
||||
if (y < -1.0 || y > height || x < -1.0 || x > width) {
|
||||
//empty
|
||||
w1 = w2 = w3 = w4 = 0.;
|
||||
x_low = x_high = y_low = y_high = -1;
|
||||
return;
|
||||
}
|
||||
|
||||
if (y <= 0) y = 0;
|
||||
if (x <= 0) x = 0;
|
||||
|
||||
y_low = (int) y;
|
||||
x_low = (int) x;
|
||||
|
||||
if (y_low >= height - 1) {
|
||||
y_high = y_low = height - 1;
|
||||
y = (T) y_low;
|
||||
} else {
|
||||
y_high = y_low + 1;
|
||||
}
|
||||
|
||||
if (x_low >= width - 1) {
|
||||
x_high = x_low = width - 1;
|
||||
x = (T) x_low;
|
||||
} else {
|
||||
x_high = x_low + 1;
|
||||
}
|
||||
|
||||
T ly = y - y_low;
|
||||
T lx = x - x_low;
|
||||
T hy = 1. - ly, hx = 1. - lx;
|
||||
|
||||
// reference in forward
|
||||
// T v1 = bottom_data[y_low * width + x_low];
|
||||
// T v2 = bottom_data[y_low * width + x_high];
|
||||
// T v3 = bottom_data[y_high * width + x_low];
|
||||
// T v4 = bottom_data[y_high * width + x_high];
|
||||
// T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
|
||||
|
||||
w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void RoIAlignBackwardFeature(const int nthreads, const T* top_diff,
|
||||
const int num_rois, const T spatial_scale,
|
||||
const int channels, const int height, const int width,
|
||||
const int pooled_height, const int pooled_width,
|
||||
const int sampling_ratio,
|
||||
T* bottom_diff,
|
||||
const T* bottom_rois) {
|
||||
CUDA_1D_KERNEL_LOOP(index, nthreads) {
|
||||
// (n, c, ph, pw) is an element in the pooled output
|
||||
int pw = index % pooled_width;
|
||||
int ph = (index / pooled_width) % pooled_height;
|
||||
int c = (index / pooled_width / pooled_height) % channels;
|
||||
int n = index / pooled_width / pooled_height / channels;
|
||||
|
||||
const T* offset_bottom_rois = bottom_rois + n * 5;
|
||||
int roi_batch_ind = offset_bottom_rois[0];
|
||||
|
||||
// Do not use rounding; this implementation detail is critical
|
||||
T roi_start_w = offset_bottom_rois[1] * spatial_scale;
|
||||
T roi_start_h = offset_bottom_rois[2] * spatial_scale;
|
||||
T roi_end_w = offset_bottom_rois[3] * spatial_scale;
|
||||
T roi_end_h = offset_bottom_rois[4] * spatial_scale;
|
||||
// T roi_start_w = round(offset_bottom_rois[1] * spatial_scale);
|
||||
// T roi_start_h = round(offset_bottom_rois[2] * spatial_scale);
|
||||
// T roi_end_w = round(offset_bottom_rois[3] * spatial_scale);
|
||||
// T roi_end_h = round(offset_bottom_rois[4] * spatial_scale);
|
||||
|
||||
// Force malformed ROIs to be 1x1
|
||||
T roi_width = max(roi_end_w - roi_start_w, (T)1.);
|
||||
T roi_height = max(roi_end_h - roi_start_h, (T)1.);
|
||||
T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
|
||||
T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
|
||||
|
||||
T* offset_bottom_diff = bottom_diff + (roi_batch_ind * channels + c) * height * width;
|
||||
|
||||
int top_offset = (n * channels + c) * pooled_height * pooled_width;
|
||||
const T* offset_top_diff = top_diff + top_offset;
|
||||
const T top_diff_this_bin = offset_top_diff[ph * pooled_width + pw];
|
||||
|
||||
// We use roi_bin_grid to sample the grid and mimic integral
|
||||
int roi_bin_grid_h = (sampling_ratio > 0) ? sampling_ratio : ceilf(roi_height / pooled_height); // e.g., = 2
|
||||
int roi_bin_grid_w = (sampling_ratio > 0) ? sampling_ratio : ceilf(roi_width / pooled_width);
|
||||
|
||||
// We do average (integral) pooling inside a bin
|
||||
const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4
|
||||
|
||||
for (int iy = 0; iy < roi_bin_grid_h; iy ++) // e.g., iy = 0, 1
|
||||
{
|
||||
const T y = roi_start_h + ph * bin_size_h + static_cast<T>(iy + .5f) * bin_size_h / static_cast<T>(roi_bin_grid_h); // e.g., 0.5, 1.5
|
||||
for (int ix = 0; ix < roi_bin_grid_w; ix ++)
|
||||
{
|
||||
const T x = roi_start_w + pw * bin_size_w + static_cast<T>(ix + .5f) * bin_size_w / static_cast<T>(roi_bin_grid_w);
|
||||
|
||||
T w1, w2, w3, w4;
|
||||
int x_low, x_high, y_low, y_high;
|
||||
|
||||
bilinear_interpolate_gradient(height, width, y, x,
|
||||
w1, w2, w3, w4,
|
||||
x_low, x_high, y_low, y_high,
|
||||
index);
|
||||
|
||||
T g1 = top_diff_this_bin * w1 / count;
|
||||
T g2 = top_diff_this_bin * w2 / count;
|
||||
T g3 = top_diff_this_bin * w3 / count;
|
||||
T g4 = top_diff_this_bin * w4 / count;
|
||||
|
||||
if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0)
|
||||
{
|
||||
atomicAdd(offset_bottom_diff + y_low * width + x_low, static_cast<T>(g1));
|
||||
atomicAdd(offset_bottom_diff + y_low * width + x_high, static_cast<T>(g2));
|
||||
atomicAdd(offset_bottom_diff + y_high * width + x_low, static_cast<T>(g3));
|
||||
atomicAdd(offset_bottom_diff + y_high * width + x_high, static_cast<T>(g4));
|
||||
} // if
|
||||
} // ix
|
||||
} // iy
|
||||
} // CUDA_1D_KERNEL_LOOP
|
||||
} // RoIAlignBackward
|
||||
|
||||
|
||||
at::Tensor ROIAlign_forward_cuda(const at::Tensor& input,
|
||||
const at::Tensor& rois,
|
||||
const float spatial_scale,
|
||||
const int pooled_height,
|
||||
const int pooled_width,
|
||||
const int sampling_ratio) {
|
||||
AT_ASSERTM(input.type().is_cuda(), "input must be a CUDA tensor");
|
||||
AT_ASSERTM(rois.type().is_cuda(), "rois must be a CUDA tensor");
|
||||
|
||||
auto num_rois = rois.size(0);
|
||||
auto channels = input.size(1);
|
||||
auto height = input.size(2);
|
||||
auto width = input.size(3);
|
||||
|
||||
auto output = at::empty({num_rois, channels, pooled_height, pooled_width}, input.options());
|
||||
auto output_size = num_rois * pooled_height * pooled_width * channels;
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
dim3 grid(std::min(THCCeilDiv((long)output_size, 512L), 4096L));
|
||||
dim3 block(512);
|
||||
|
||||
if (output.numel() == 0) {
|
||||
THCudaCheck(cudaGetLastError());
|
||||
return output;
|
||||
}
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES(input.type(), "ROIAlign_forward", [&] {
|
||||
RoIAlignForward<scalar_t><<<grid, block, 0, stream>>>(
|
||||
output_size,
|
||||
input.contiguous().data<scalar_t>(),
|
||||
spatial_scale,
|
||||
channels,
|
||||
height,
|
||||
width,
|
||||
pooled_height,
|
||||
pooled_width,
|
||||
sampling_ratio,
|
||||
rois.contiguous().data<scalar_t>(),
|
||||
output.data<scalar_t>());
|
||||
});
|
||||
THCudaCheck(cudaGetLastError());
|
||||
return output;
|
||||
}
|
||||
|
||||
// TODO remove the dependency on input and use instead its sizes -> save memory
|
||||
at::Tensor ROIAlign_backward_cuda(const at::Tensor& grad,
|
||||
const at::Tensor& rois,
|
||||
const float spatial_scale,
|
||||
const int pooled_height,
|
||||
const int pooled_width,
|
||||
const int batch_size,
|
||||
const int channels,
|
||||
const int height,
|
||||
const int width,
|
||||
const int sampling_ratio) {
|
||||
AT_ASSERTM(grad.type().is_cuda(), "grad must be a CUDA tensor");
|
||||
AT_ASSERTM(rois.type().is_cuda(), "rois must be a CUDA tensor");
|
||||
|
||||
auto num_rois = rois.size(0);
|
||||
auto grad_input = at::zeros({batch_size, channels, height, width}, grad.options());
|
||||
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
dim3 grid(std::min(THCCeilDiv((long)grad.numel(), 512L), 4096L));
|
||||
dim3 block(512);
|
||||
|
||||
// handle possibly empty gradients
|
||||
if (grad.numel() == 0) {
|
||||
THCudaCheck(cudaGetLastError());
|
||||
return grad_input;
|
||||
}
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES(grad.type(), "ROIAlign_backward", [&] {
|
||||
RoIAlignBackwardFeature<scalar_t><<<grid, block, 0, stream>>>(
|
||||
grad.numel(),
|
||||
grad.contiguous().data<scalar_t>(),
|
||||
num_rois,
|
||||
spatial_scale,
|
||||
channels,
|
||||
height,
|
||||
width,
|
||||
pooled_height,
|
||||
pooled_width,
|
||||
sampling_ratio,
|
||||
grad_input.data<scalar_t>(),
|
||||
rois.contiguous().data<scalar_t>());
|
||||
});
|
||||
THCudaCheck(cudaGetLastError());
|
||||
return grad_input;
|
||||
}
|
|
@ -1,202 +0,0 @@
|
|||
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
|
||||
#include <THC/THC.h>
|
||||
#include <THC/THCAtomics.cuh>
|
||||
#include <THC/THCDeviceUtils.cuh>
|
||||
|
||||
|
||||
// TODO make it in a common file
|
||||
#define CUDA_1D_KERNEL_LOOP(i, n) \
|
||||
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \
|
||||
i += blockDim.x * gridDim.x)
|
||||
|
||||
|
||||
template <typename T>
|
||||
__global__ void RoIPoolFForward(const int nthreads, const T* bottom_data,
|
||||
const T spatial_scale, const int channels, const int height,
|
||||
const int width, const int pooled_height, const int pooled_width,
|
||||
const T* bottom_rois, T* top_data, int* argmax_data) {
|
||||
CUDA_1D_KERNEL_LOOP(index, nthreads) {
|
||||
// (n, c, ph, pw) is an element in the pooled output
|
||||
int pw = index % pooled_width;
|
||||
int ph = (index / pooled_width) % pooled_height;
|
||||
int c = (index / pooled_width / pooled_height) % channels;
|
||||
int n = index / pooled_width / pooled_height / channels;
|
||||
|
||||
const T* offset_bottom_rois = bottom_rois + n * 5;
|
||||
int roi_batch_ind = offset_bottom_rois[0];
|
||||
int roi_start_w = roundf(offset_bottom_rois[1] * spatial_scale);
|
||||
int roi_start_h = roundf(offset_bottom_rois[2] * spatial_scale);
|
||||
int roi_end_w = roundf(offset_bottom_rois[3] * spatial_scale);
|
||||
int roi_end_h = roundf(offset_bottom_rois[4] * spatial_scale);
|
||||
|
||||
// Force malformed ROIs to be 1x1
|
||||
int roi_width = max(roi_end_w - roi_start_w + 1, 1);
|
||||
int roi_height = max(roi_end_h - roi_start_h + 1, 1);
|
||||
T bin_size_h = static_cast<T>(roi_height)
|
||||
/ static_cast<T>(pooled_height);
|
||||
T bin_size_w = static_cast<T>(roi_width)
|
||||
/ static_cast<T>(pooled_width);
|
||||
|
||||
int hstart = static_cast<int>(floorf(static_cast<T>(ph)
|
||||
* bin_size_h));
|
||||
int wstart = static_cast<int>(floorf(static_cast<T>(pw)
|
||||
* bin_size_w));
|
||||
int hend = static_cast<int>(ceilf(static_cast<T>(ph + 1)
|
||||
* bin_size_h));
|
||||
int wend = static_cast<int>(ceilf(static_cast<T>(pw + 1)
|
||||
* bin_size_w));
|
||||
|
||||
// Add roi offsets and clip to input boundaries
|
||||
hstart = min(max(hstart + roi_start_h, 0), height);
|
||||
hend = min(max(hend + roi_start_h, 0), height);
|
||||
wstart = min(max(wstart + roi_start_w, 0), width);
|
||||
wend = min(max(wend + roi_start_w, 0), width);
|
||||
bool is_empty = (hend <= hstart) || (wend <= wstart);
|
||||
|
||||
// Define an empty pooling region to be zero
|
||||
T maxval = is_empty ? 0 : -FLT_MAX;
|
||||
// If nothing is pooled, argmax = -1 causes nothing to be backprop'd
|
||||
int maxidx = -1;
|
||||
const T* offset_bottom_data =
|
||||
bottom_data + (roi_batch_ind * channels + c) * height * width;
|
||||
for (int h = hstart; h < hend; ++h) {
|
||||
for (int w = wstart; w < wend; ++w) {
|
||||
int bottom_index = h * width + w;
|
||||
if (offset_bottom_data[bottom_index] > maxval) {
|
||||
maxval = offset_bottom_data[bottom_index];
|
||||
maxidx = bottom_index;
|
||||
}
|
||||
}
|
||||
}
|
||||
top_data[index] = maxval;
|
||||
argmax_data[index] = maxidx;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void RoIPoolFBackward(const int nthreads, const T* top_diff,
|
||||
const int* argmax_data, const int num_rois, const T spatial_scale,
|
||||
const int channels, const int height, const int width,
|
||||
const int pooled_height, const int pooled_width, T* bottom_diff,
|
||||
const T* bottom_rois) {
|
||||
CUDA_1D_KERNEL_LOOP(index, nthreads) {
|
||||
// (n, c, ph, pw) is an element in the pooled output
|
||||
int pw = index % pooled_width;
|
||||
int ph = (index / pooled_width) % pooled_height;
|
||||
int c = (index / pooled_width / pooled_height) % channels;
|
||||
int n = index / pooled_width / pooled_height / channels;
|
||||
|
||||
const T* offset_bottom_rois = bottom_rois + n * 5;
|
||||
int roi_batch_ind = offset_bottom_rois[0];
|
||||
int bottom_offset = (roi_batch_ind * channels + c) * height * width;
|
||||
int top_offset = (n * channels + c) * pooled_height * pooled_width;
|
||||
const T* offset_top_diff = top_diff + top_offset;
|
||||
T* offset_bottom_diff = bottom_diff + bottom_offset;
|
||||
const int* offset_argmax_data = argmax_data + top_offset;
|
||||
|
||||
int argmax = offset_argmax_data[ph * pooled_width + pw];
|
||||
if (argmax != -1) {
|
||||
atomicAdd(
|
||||
offset_bottom_diff + argmax,
|
||||
static_cast<T>(offset_top_diff[ph * pooled_width + pw]));
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::tuple<at::Tensor, at::Tensor> ROIPool_forward_cuda(const at::Tensor& input,
|
||||
const at::Tensor& rois,
|
||||
const float spatial_scale,
|
||||
const int pooled_height,
|
||||
const int pooled_width) {
|
||||
AT_ASSERTM(input.type().is_cuda(), "input must be a CUDA tensor");
|
||||
AT_ASSERTM(rois.type().is_cuda(), "rois must be a CUDA tensor");
|
||||
|
||||
auto num_rois = rois.size(0);
|
||||
auto channels = input.size(1);
|
||||
auto height = input.size(2);
|
||||
auto width = input.size(3);
|
||||
|
||||
auto output = at::empty({num_rois, channels, pooled_height, pooled_width}, input.options());
|
||||
auto output_size = num_rois * pooled_height * pooled_width * channels;
|
||||
auto argmax = at::zeros({num_rois, channels, pooled_height, pooled_width}, input.options().dtype(at::kInt));
|
||||
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
dim3 grid(std::min(THCCeilDiv((long)output_size, 512L), 4096L));
|
||||
dim3 block(512);
|
||||
|
||||
if (output.numel() == 0) {
|
||||
THCudaCheck(cudaGetLastError());
|
||||
return std::make_tuple(output, argmax);
|
||||
}
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES(input.type(), "ROIPool_forward", [&] {
|
||||
RoIPoolFForward<scalar_t><<<grid, block, 0, stream>>>(
|
||||
output_size,
|
||||
input.contiguous().data<scalar_t>(),
|
||||
spatial_scale,
|
||||
channels,
|
||||
height,
|
||||
width,
|
||||
pooled_height,
|
||||
pooled_width,
|
||||
rois.contiguous().data<scalar_t>(),
|
||||
output.data<scalar_t>(),
|
||||
argmax.data<int>());
|
||||
});
|
||||
THCudaCheck(cudaGetLastError());
|
||||
return std::make_tuple(output, argmax);
|
||||
}
|
||||
|
||||
// TODO remove the dependency on input and use instead its sizes -> save memory
|
||||
at::Tensor ROIPool_backward_cuda(const at::Tensor& grad,
|
||||
const at::Tensor& input,
|
||||
const at::Tensor& rois,
|
||||
const at::Tensor& argmax,
|
||||
const float spatial_scale,
|
||||
const int pooled_height,
|
||||
const int pooled_width,
|
||||
const int batch_size,
|
||||
const int channels,
|
||||
const int height,
|
||||
const int width) {
|
||||
AT_ASSERTM(grad.type().is_cuda(), "grad must be a CUDA tensor");
|
||||
AT_ASSERTM(rois.type().is_cuda(), "rois must be a CUDA tensor");
|
||||
// TODO add more checks
|
||||
|
||||
auto num_rois = rois.size(0);
|
||||
auto grad_input = at::zeros({batch_size, channels, height, width}, grad.options());
|
||||
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
dim3 grid(std::min(THCCeilDiv((long)grad.numel(), 512L), 4096L));
|
||||
dim3 block(512);
|
||||
|
||||
// handle possibly empty gradients
|
||||
if (grad.numel() == 0) {
|
||||
THCudaCheck(cudaGetLastError());
|
||||
return grad_input;
|
||||
}
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES(grad.type(), "ROIPool_backward", [&] {
|
||||
RoIPoolFBackward<scalar_t><<<grid, block, 0, stream>>>(
|
||||
grad.numel(),
|
||||
grad.contiguous().data<scalar_t>(),
|
||||
argmax.data<int>(),
|
||||
num_rois,
|
||||
spatial_scale,
|
||||
channels,
|
||||
height,
|
||||
width,
|
||||
pooled_height,
|
||||
pooled_width,
|
||||
grad_input.data<scalar_t>(),
|
||||
rois.contiguous().data<scalar_t>());
|
||||
});
|
||||
THCudaCheck(cudaGetLastError());
|
||||
return grad_input;
|
||||
}
|
|
@ -1,375 +0,0 @@
|
|||
// Copyright (c) Jianqi Ma. All Rights Reserved.
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
|
||||
#include <THC/THC.h>
|
||||
#include <THC/THCAtomics.cuh>
|
||||
#include <THC/THCDeviceUtils.cuh>
|
||||
|
||||
#include <stdio.h>
|
||||
#include <math.h>
|
||||
#include <float.h>
|
||||
//#include "rroi_alignment_kernel.h"
|
||||
|
||||
#define CUDA_1D_KERNEL_LOOP(i, n) \
|
||||
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \
|
||||
i += blockDim.x * gridDim.x)
|
||||
|
||||
template <typename T>
|
||||
__global__ void RROIAlignForward(
|
||||
const int nthreads,
|
||||
const T* bottom_data,
|
||||
const T spatial_scale,
|
||||
int height,
|
||||
int width,
|
||||
int channels,
|
||||
const int pooled_height,
|
||||
const int pooled_width,
|
||||
const T* bottom_rois,
|
||||
T* top_data,
|
||||
float* con_idx_x,
|
||||
float* con_idx_y)
|
||||
{
|
||||
|
||||
CUDA_1D_KERNEL_LOOP(index, nthreads)
|
||||
{
|
||||
// +0.5 shift removed
|
||||
int imageWidth = width;
|
||||
int imageHeight = height;
|
||||
|
||||
// (n, c, ph, pw) is an element in the pooled output
|
||||
int n = index;
|
||||
int pw = n % pooled_width;
|
||||
n /= pooled_width;
|
||||
int ph = n % pooled_height;
|
||||
n /= pooled_height;
|
||||
int c = n % channels;
|
||||
n /= channels;
|
||||
|
||||
const T* offset_bottom_rois = bottom_rois + n * 7; // each roi has 7 values: (batch_ind, cx, cy, h, w, Alpha, Beta)
|
||||
|
||||
int roi_batch_ind = offset_bottom_rois[0];
|
||||
T cx = offset_bottom_rois[1];
|
||||
T cy = offset_bottom_rois[2];
|
||||
T h = offset_bottom_rois[3];
|
||||
T w = offset_bottom_rois[4];
|
||||
//T angle = offset_bottom_rois[5]/180.0*3.1415926535;
|
||||
T Alpha = offset_bottom_rois[5];
|
||||
T Beta = offset_bottom_rois[6];
|
||||
|
||||
//TransformPrepare
|
||||
T dx = -pooled_width/2.0;
|
||||
T dy = -pooled_height/2.0;
|
||||
T Sx = w*spatial_scale/pooled_width;
|
||||
T Sy = h*spatial_scale/pooled_height;
|
||||
//T Alpha = cos(angle);
|
||||
//T Beta = sin(angle);
|
||||
T Dx = cx*spatial_scale;
|
||||
T Dy = cy*spatial_scale;
|
||||
|
||||
T M[2][3];
|
||||
M[0][0] = Alpha*Sx;
|
||||
M[0][1] = Beta*Sy;
|
||||
M[0][2] = Alpha*Sx*dx+Beta*Sy*dy+Dx;
|
||||
M[1][0] = -Beta*Sx;
|
||||
M[1][1] = Alpha*Sy;
|
||||
M[1][2] = -Beta*Sx*dx+Alpha*Sy*dy+Dy;
|
||||
|
||||
T P[8];
|
||||
P[0] = M[0][0]*pw+M[0][1]*ph+M[0][2];
|
||||
P[1] = M[1][0]*pw+M[1][1]*ph+M[1][2];
|
||||
P[2] = M[0][0]*pw+M[0][1]*(ph+1)+M[0][2];
|
||||
P[3] = M[1][0]*pw+M[1][1]*(ph+1)+M[1][2];
|
||||
P[4] = M[0][0]*(pw+1)+M[0][1]*ph+M[0][2];
|
||||
P[5] = M[1][0]*(pw+1)+M[1][1]*ph+M[1][2];
|
||||
P[6] = M[0][0]*(pw+1)+M[0][1]*(ph+1)+M[0][2];
|
||||
P[7] = M[1][0]*(pw+1)+M[1][1]*(ph+1)+M[1][2];
|
||||
|
||||
T leftMost = (max(roundf(min(min(P[0],P[2]),min(P[4],P[6]))),0.0));
|
||||
T rightMost= (min(roundf(max(max(P[0],P[2]),max(P[4],P[6]))),imageWidth-1.0));
|
||||
T topMost= (max(roundf(min(min(P[1],P[3]),min(P[5],P[7]))),0.0));
|
||||
T bottomMost= (min(roundf(max(max(P[1],P[3]),max(P[5],P[7]))),imageHeight-1.0));
|
||||
|
||||
//float maxval = 0;
|
||||
//int maxidx = -1;
|
||||
const T* offset_bottom_data = bottom_data + (roi_batch_ind * channels + c) * height * width;
|
||||
|
||||
//float AB[2];
|
||||
//AB[0] = P[2] - P[0];
|
||||
//AB[1] = P[3] - P[1];
|
||||
//float ABAB = AB[0]*AB[0] +AB[1]*AB[1];
|
||||
//float AC[2];
|
||||
//AC[0] = P[4] - P[0];
|
||||
//AC[1] = P[5] - P[1];
|
||||
//float ACAC = AC[0]*AC[0] + AC[1]*AC[1];
|
||||
|
||||
float bin_cx = (leftMost + rightMost) / 2.0; // shift
|
||||
float bin_cy = (topMost + bottomMost) / 2.0;
|
||||
|
||||
int bin_l = (int)floorf(bin_cx);
|
||||
int bin_r = (int)ceilf(bin_cx);
|
||||
int bin_t = (int)floorf(bin_cy);
|
||||
int bin_b = (int)ceilf(bin_cy);
|
||||
|
||||
T lt_value = 0.0;
|
||||
if (bin_t > 0 && bin_l > 0 && bin_t < height && bin_l < width)
|
||||
lt_value = offset_bottom_data[bin_t * width + bin_l];
|
||||
T rt_value = 0.0;
|
||||
if (bin_t > 0 && bin_r > 0 && bin_t < height && bin_r < width)
|
||||
rt_value = offset_bottom_data[bin_t * width + bin_r];
|
||||
T lb_value = 0.0;
|
||||
if (bin_b > 0 && bin_l > 0 && bin_b < height && bin_l < width)
|
||||
lb_value = offset_bottom_data[bin_b * width + bin_l];
|
||||
T rb_value = 0.0;
|
||||
if (bin_b > 0 && bin_r > 0 && bin_b < height && bin_r < width)
|
||||
rb_value = offset_bottom_data[bin_b * width + bin_r];
|
||||
|
||||
T rx = bin_cx - floorf(bin_cx);
|
||||
T ry = bin_cy - floorf(bin_cy);
|
||||
|
||||
T wlt = (1.0 - rx) * (1.0 - ry);
|
||||
T wrt = rx * (1.0 - ry);
|
||||
T wrb = rx * ry;
|
||||
T wlb = (1.0 - rx) * ry;
|
||||
|
||||
T inter_val = 0.0;
|
||||
|
||||
inter_val += lt_value * wlt;
|
||||
inter_val += rt_value * wrt;
|
||||
inter_val += rb_value * wrb;
|
||||
inter_val += lb_value * wlb;
|
||||
|
||||
atomicAdd(top_data + index, static_cast<T>(inter_val));
|
||||
atomicAdd(con_idx_x + index, static_cast<float>(bin_cx));
|
||||
atomicAdd(con_idx_y + index, static_cast<float>(bin_cy));
|
||||
|
||||
//top_data[index] = static_cast<T>(inter_val);
|
||||
//con_idx_x[index] = bin_cx;
|
||||
//con_idx_y[index] = bin_cy;
|
||||
|
||||
}
|
||||
}
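For reference, here is a minimal NumPy sketch of the bin-to-feature-map mapping computed by `RROIAlignForward` above. The function and its names are purely illustrative and not part of MMOCR; it maps the bin centre directly instead of rounding and clipping the four mapped corners as the kernel does.

```python
import numpy as np


def rroi_bin_center(roi, pw, ph, pooled_w, pooled_h, spatial_scale):
    """Map pooled bin (pw, ph) of a rotated roi to feature-map coordinates.

    roi = (cx, cy, h, w, alpha, beta) with alpha = cos(theta), beta = sin(theta),
    mirroring what RROIAlignForward reads from columns 1..6 of each roi row.
    """
    cx, cy, h, w, alpha, beta = roi
    dx, dy = -pooled_w / 2.0, -pooled_h / 2.0
    sx, sy = w * spatial_scale / pooled_w, h * spatial_scale / pooled_h
    M = np.array([
        [alpha * sx,  beta * sy, alpha * sx * dx + beta * sy * dy + cx * spatial_scale],
        [-beta * sx, alpha * sy, -beta * sx * dx + alpha * sy * dy + cy * spatial_scale],
    ])
    # Before rounding and clipping, the midpoint of the mapped corner extremes
    # equals the mapped bin centre, so we map the centre (pw + 0.5, ph + 0.5) directly.
    return M @ np.array([pw + 0.5, ph + 0.5, 1.0])
```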
/**
int RROIAlignForwardLaucher(
    const float* bottom_data, const float spatial_scale, const int num_rois, const int height,
    const int width, const int channels, const int pooled_height,
    const int pooled_width, const float* bottom_rois,
    float* top_data, float* con_idx_x, float* con_idx_y, const float* im_info, cudaStream_t stream)
{
    const int kThreadsPerBlock = 1024;
    const int output_size = num_rois * pooled_height * pooled_width * channels;
    cudaError_t err;

    RROIAlignForward<<<(output_size + kThreadsPerBlock - 1) / kThreadsPerBlock, kThreadsPerBlock, 0, stream>>>(
        output_size, bottom_data, spatial_scale, height, width, channels, pooled_height,
        pooled_width, bottom_rois, top_data, con_idx_x, con_idx_y, im_info);

    err = cudaGetLastError();
    if (cudaSuccess != err)
    {
        fprintf(stderr, "RRoI forward cudaCheckError() failed : %s\n", cudaGetErrorString(err));
        exit(-1);
    }

    return 1;
}
*/
//ROIAlign_forward_cuda

//std::tuple<at::Tensor, at::Tensor> ROIPool_forward_cuda(const at::Tensor& input,
//                                                        const at::Tensor& rois,
//                                                        const float spatial_scale,
//                                                        const int pooled_height,
//                                                        const int pooled_width)

std::tuple<at::Tensor, at::Tensor, at::Tensor> RROIAlign_forward_cuda(
    const at::Tensor& input,
    const at::Tensor& rois,
    const float spatial_scale,
    const int pooled_height,
    const int pooled_width)
{
  AT_ASSERTM(input.type().is_cuda(), "input must be a CUDA tensor");
  AT_ASSERTM(rois.type().is_cuda(), "rois must be a CUDA tensor");

  auto num_rois = rois.size(0);
  auto channels = input.size(1);
  auto height = input.size(2);
  auto width = input.size(3);

  auto output = at::zeros({num_rois, channels, pooled_height, pooled_width}, input.options());
  auto output_size = num_rois * pooled_height * pooled_width * channels;
  auto con_idx_x = at::zeros({num_rois, channels, pooled_height, pooled_width}, input.options().dtype(at::kFloat));
  auto con_idx_y = at::zeros({num_rois, channels, pooled_height, pooled_width}, input.options().dtype(at::kFloat));

  dim3 grid(std::min(THCCeilDiv((long)output_size, 512L), 4096L));
  dim3 block(512);

  //const int kThreadsPerBlock = 1024;

  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  if (output.numel() == 0) {
    THCudaCheck(cudaGetLastError());
    return std::make_tuple(output, con_idx_x, con_idx_y);
  }

  AT_DISPATCH_FLOATING_TYPES(input.type(), "RROIAlign_forward", [&] {
    RROIAlignForward<scalar_t><<<grid, block, 0, stream>>>(
        output_size,
        input.contiguous().data<scalar_t>(),
        spatial_scale,
        height,
        width,
        channels,
        pooled_height,
        pooled_width,
        rois.contiguous().data<scalar_t>(),
        output.data<scalar_t>(),
        con_idx_x.data<float>(),
        con_idx_y.data<float>());
  });

  THCudaCheck(cudaGetLastError());
  return std::make_tuple(output, con_idx_x, con_idx_y);
}
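A note on the launch configuration in `RROIAlign_forward_cuda` above: the grid is capped at 4096 blocks of 512 threads, and the grid-stride `CUDA_1D_KERNEL_LOOP` lets each thread cover several output elements whenever `output_size` exceeds `grid * block`. A small illustrative sketch (the helper name and example sizes are ours):

```python
def launch_config(output_size, threads_per_block=512, max_blocks=4096):
    """Illustrative mirror of the grid/block choice in RROIAlign_forward_cuda."""
    blocks = min((output_size + threads_per_block - 1) // threads_per_block, max_blocks)
    return blocks, threads_per_block


# e.g. 3 rois, 256 channels, 7x7 bins -> 37632 outputs -> 74 blocks of 512 threads
print(launch_config(3 * 256 * 7 * 7))
```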
template <typename T>
__global__ void RROIAlignBackward(
    const int nthreads,
    const T* top_diff,
    const float* con_idx_x,
    const float* con_idx_y,
    const int num_rois,
    const float spatial_scale,
    const int height,
    const int width,
    const int channels,
    const int pooled_height,
    const int pooled_width,
    T* bottom_diff,
    const T* bottom_rois) {
  CUDA_1D_KERNEL_LOOP(index, nthreads)
  {
    // (n, c, ph, pw) is an element in the pooled output
    int n = index;
    //int w = n % width;
    n /= pooled_width;
    //int h = n % height;
    n /= pooled_height;
    int c = n % channels;
    n /= channels;

    const T* offset_bottom_rois = bottom_rois + n * 7;
    int roi_batch_ind = offset_bottom_rois[0];
    T* offset_bottom_diff = bottom_diff + (roi_batch_ind * channels + c) * height * width;

    //int bottom_index = argmax_data[index];

    float bw = con_idx_x[index];
    float bh = con_idx_y[index];
    //if (bh > 0.00001 && bw > 0.00001 && bw < height-1 && bw < width-1){

    int bin_xs = int(floorf(bw));
    int bin_ys = int(floorf(bh));

    float rx = bw - float(bin_xs);
    float ry = bh - float(bin_ys);

    T wlt = (1.0 - rx) * (1.0 - ry);
    T wrt = rx * (1.0 - ry);
    T wrb = rx * ry;
    T wlb = (1.0 - rx) * ry;

    // if(bottom_index >= 0) // original != 0 maybe wrong
    //     bottom_diff[bottom_index]+=top_diff[index] ;

    //int min_x = bin_xs, 0), width - 1);
    //int min_y = min(max(bin_ys, 0), height - 1);
    //int max_x = max(min(bin_xs + 1, width - 1), 0);
    //int max_y = max(min(bin_ys + 1, height - 1), 0);

    int min_x = (int)floorf(bw);
    int max_x = (int)ceilf(bw);
    int min_y = (int)floorf(bh);
    int max_y = (int)ceilf(bh);

    T top_diff_of_bin = top_diff[index];

    T v1 = wlt * top_diff_of_bin;
    T v2 = wrt * top_diff_of_bin;
    T v3 = wrb * top_diff_of_bin;
    T v4 = wlb * top_diff_of_bin;

    // Atomic add
    if (min_y > 0 && min_x > 0 && min_y < height - 1 && min_x < width - 1)
      atomicAdd(offset_bottom_diff + min_y * width + min_x, static_cast<T>(v1));
    if (min_y > 0 && max_x < width - 1 && min_y < height - 1 && max_x > 0)
      atomicAdd(offset_bottom_diff + min_y * width + max_x, static_cast<T>(v2));
    if (max_y < height - 1 && max_x < width - 1 && max_y > 0 && max_x > 0)
      atomicAdd(offset_bottom_diff + max_y * width + max_x, static_cast<T>(v3));
    if (max_y < height - 1 && min_x > 0 && max_y > 0 && min_x < width - 1)
      atomicAdd(offset_bottom_diff + max_y * width + min_x, static_cast<T>(v4));

    //}
  }
}
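The backward kernel above re-reads the cached sampling centre (`con_idx_x`, `con_idx_y`) for each pooled bin and scatters the incoming gradient onto the four neighbouring feature-map cells with the same bilinear weights as the forward pass. A minimal NumPy sketch of that scatter step (the function name is ours; the guard mirrors the kernel's strictly-interior `atomicAdd` conditions):

```python
import numpy as np


def scatter_bilinear_grad(grad_map, bx, by, top_grad):
    """Scatter one pooled-bin gradient onto its four feature-map neighbours."""
    h, w = grad_map.shape
    x0, y0 = int(np.floor(bx)), int(np.floor(by))
    x1, y1 = int(np.ceil(bx)), int(np.ceil(by))
    rx, ry = bx - x0, by - y0
    neighbours = [((y0, x0), (1 - rx) * (1 - ry)),
                  ((y0, x1), rx * (1 - ry)),
                  ((y1, x1), rx * ry),
                  ((y1, x0), (1 - rx) * ry)]
    for (y, x), weight in neighbours:
        # same strictly-interior guard as the atomicAdd branches in the kernel
        if 0 < x < w - 1 and 0 < y < h - 1:
            grad_map[y, x] += weight * top_grad
    return grad_map
```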
// TODO remove the dependency on input and use instead its sizes -> save memory
at::Tensor RROIAlign_backward_cuda(const at::Tensor& grad,
                                   const at::Tensor& rois,
                                   const at::Tensor& con_idx_x,
                                   const at::Tensor& con_idx_y,
                                   const float spatial_scale,
                                   const int pooled_height,
                                   const int pooled_width,
                                   const int batch_size,
                                   const int channels,
                                   const int height,
                                   const int width) {
  AT_ASSERTM(grad.type().is_cuda(), "grad must be a CUDA tensor");
  AT_ASSERTM(rois.type().is_cuda(), "rois must be a CUDA tensor");
  // TODO add more checks

  auto num_rois = rois.size(0);
  auto grad_input = at::zeros({batch_size, channels, height, width}, grad.options());

  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  dim3 grid(std::min(THCCeilDiv((long)grad.numel(), 512L), 4096L));
  dim3 block(512);

  // handle possibly empty gradients
  if (grad.numel() == 0) {
    THCudaCheck(cudaGetLastError());
    return grad_input;
  }

  AT_DISPATCH_FLOATING_TYPES(grad.type(), "RROIAlign_backward", [&] {
    RROIAlignBackward<scalar_t><<<grid, block, 0, stream>>>(
        grad.numel(),
        grad.contiguous().data<scalar_t>(),
        con_idx_x.data<float>(),
        con_idx_y.data<float>(),
        num_rois,
        spatial_scale,
        height,
        width,
        channels,
        pooled_height,
        pooled_width,
        grad_input.data<scalar_t>(),
        rois.contiguous().data<scalar_t>());
  });
  THCudaCheck(cudaGetLastError());
  return grad_input;
}
@ -1,25 +0,0 @@
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
#pragma once
#include <torch/extension.h>

std::tuple<at::Tensor, at::Tensor, at::Tensor> RROIAlign_forward_cuda(const at::Tensor& input,
                                                                      const at::Tensor& rois,
                                                                      const float spatial_scale,
                                                                      const int pooled_height,
                                                                      const int pooled_width);

at::Tensor RROIAlign_backward_cuda(const at::Tensor& grad,
                                   const at::Tensor& rois,
                                   const at::Tensor& con_idx_x,
                                   const at::Tensor& con_idx_y,
                                   const float spatial_scale,
                                   const int pooled_height,
                                   const int pooled_width,
                                   const int batch_size,
                                   const int channels,
                                   const int height,
                                   const int width);

at::Tensor compute_flow_cuda(const at::Tensor& boxes,
                             const int height,
                             const int width);
@ -1,10 +0,0 @@
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
#include "RROIAlign.h"


PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("rroi_align_forward", &RROIAlign_forward, "RROIAlign_forward");
  m.def("rroi_align_backward", &RROIAlign_backward, "RROIAlign_backward");
}
@ -1,57 +0,0 @@
from torch import nn
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.nn.modules.utils import _pair

from .csrc import rroi_align_backward, rroi_align_forward


class _RROIAlign(Function):

    @staticmethod
    def forward(ctx, input, roi, output_size, spatial_scale):
        ctx.output_size = _pair(output_size)
        ctx.spatial_scale = spatial_scale
        ctx.input_shape = input.size()
        total_output = rroi_align_forward(input, roi, spatial_scale,
                                          output_size[0], output_size[1])

        output, con_idx_x, con_idx_y = total_output
        ctx.save_for_backward(roi, con_idx_x, con_idx_y)

        return output

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        rois, con_idx_x, con_idx_y = ctx.saved_tensors
        output_size = ctx.output_size
        spatial_scale = ctx.spatial_scale
        bs, ch, h, w = ctx.input_shape
        grad_input = rroi_align_backward(grad_output, rois, con_idx_x,
                                         con_idx_y, spatial_scale,
                                         output_size[0], output_size[1], bs,
                                         ch, h, w)
        return grad_input, None, None, None


rroi_align = _RROIAlign.apply


class RROIAlign(nn.Module):

    def __init__(self, output_size, spatial_scale):
        super(RROIAlign, self).__init__()
        self.output_size = output_size
        self.spatial_scale = spatial_scale

    def forward(self, input, rois):
        return rroi_align(input, rois, self.output_size, self.spatial_scale)

    def __repr__(self):
        tmpstr = self.__class__.__name__ + '('
        tmpstr += 'output_size=' + str(self.output_size)
        tmpstr += ', spatial_scale=' + str(self.spatial_scale)
        tmpstr += ')'
        return tmpstr
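A hypothetical usage sketch of the `RROIAlign` module removed above; it assumes the compiled `csrc` extension and a CUDA device, and that each roi row follows the layout the CUDA kernel reads: `(batch_ind, cx, cy, h, w, cos θ, sin θ)`. The concrete numbers are only illustrative.

```python
import math
import torch

# Hypothetical usage of the removed RROIAlign module; every value is illustrative.
theta = math.radians(30)
feat = torch.randn(1, 256, 64, 64, device='cuda')            # stride-8 feature map
rois = torch.tensor([[0., 320., 240., 48., 160.,             # batch_ind, cx, cy, h, w
                      math.cos(theta), math.sin(theta)]],     # cos/sin of the box angle
                    device='cuda')
layer = RROIAlign(output_size=(8, 32), spatial_scale=1.0 / 8)
pooled = layer(feat, rois)   # expected shape: (1, 256, 8, 32)
```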
@ -1,2 +1,2 @@
mmcv-full>=1.3.1
mmcv-full>=1.3.4
mmdet>=2.11.0
setup.py
@ -1,12 +1,7 @@
import glob
import os
import sys
from setuptools import find_packages, setup

import torch
from torch.utils.cpp_extension import (CUDA_HOME, BuildExtension, CppExtension,
                                        CUDAExtension)


def readme():
    with open('README.md', encoding='utf-8') as f:

@ -103,47 +98,11 @@ def parse_requirements(fname='requirements.txt', with_version=True):
    return packages


def get_rroi_align_extensions():
    extensions_dir = 'mmocr/models/utils/ops/rroi_align/csrc/csc'
    main_file = glob.glob(os.path.join(extensions_dir, '*.cpp'))
    source_cpu = glob.glob(os.path.join(extensions_dir, 'cpu', '*.cpp'))
    source_cuda = glob.glob(os.path.join(extensions_dir, 'cuda', '*.cu'))
    sources = main_file + source_cpu
    extension = CppExtension
    extra_compile_args = {'cxx': ['/utf-8']} if is_windows else {'cxx': []}
    define_macros = []

    if torch.cuda.is_available() and CUDA_HOME is not None:
        extension = CUDAExtension
        sources += source_cuda
        define_macros += [('WITH_CUDA', None)]
        extra_compile_args['nvcc'] = [
            '-DCUDA_HAS_FP16=1',
            '-D__CUDA_NO_HALF_OPERATORS__',
            '-D__CUDA_NO_HALF_CONVERSIONS__',
            '-D__CUDA_NO_HALF2_OPERATORS__',
        ]

    print(sources)
    include_dirs = [extensions_dir]
    print('include_dirs', include_dirs, flush=True)
    ext = extension(
        name='mmocr.models.utils.ops.rroi_align.csrc',
        sources=sources,
        include_dirs=include_dirs,
        define_macros=define_macros,
        extra_compile_args=extra_compile_args,
    )

    return ext


if __name__ == '__main__':
    library_dirs = [
        lp for lp in os.environ.get('LD_LIBRARY_PATH', '').split(':')
        if len(lp) > 1
    ]
    cpp_root = 'mmocr/models/textdet/postprocess/'
    setup(
        name='mmocr',
        version=get_version(),

@ -156,7 +115,6 @@ if __name__ == '__main__':
        packages=find_packages(exclude=('configs', 'tools', 'demo')),
        include_package_data=True,
        url='https://github.com/open-mmlab/mmocr',
        package_data={'mmocr.ops': ['*/*.so']},
        classifiers=[
            'Development Status :: 4 - Beta',
            'License :: OSI Approved :: Apache Software License',

@ -176,16 +134,4 @@ if __name__ == '__main__':
            'build': parse_requirements('requirements/build.txt'),
            'optional': parse_requirements('requirements/optional.txt'),
        },
        ext_modules=[
            CppExtension(
                name='mmocr.models.textdet.postprocess.pan',
                sources=[cpp_root + 'pan.cpp'],
                extra_compile_args=(['/utf-8'] if is_windows else [])),
            CppExtension(
                name='mmocr.models.textdet.postprocess.pse',
                sources=[cpp_root + 'pse.cpp'],
                extra_compile_args=(['/utf-8'] if is_windows else [])),
            get_rroi_align_extensions()
        ],
        cmdclass={'build_ext': BuildExtension},
        zip_safe=False)
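With `ext_modules` emptied in the hunks above, `pip install mmocr` no longer compiles the pan/pse or rroi_align extensions locally; the equivalent C++ post-processing ops are expected to come from `mmcv-full` instead. A hedged sketch of what the replacement import might look like (the op names are our assumption and are not shown in this diff):

```python
# Sketch only: the pan/pse post-processing is expected to come from mmcv-full
# (>= 1.3.4) rather than from locally compiled CppExtensions.
from mmcv.ops import contour_expand, pixel_group  # noqa: F401
```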