2021-01-30 15:40:05 +08:00
|
|
|
#pragma once
|
|
|
|
|
|
|
|
#include <map>
|
|
|
|
#include "NvInfer.h"
|
2021-02-13 21:14:14 +08:00
|
|
|
#include "fastrt/module.h"
|
2021-01-30 15:40:05 +08:00
|
|
|
#include "fastrt/struct.h"
|
2021-02-13 21:28:29 +08:00
|
|
|
#include "fastrt/factory.h"
|
2021-01-30 15:40:05 +08:00
|
|
|
using namespace nvinfer1;
|
|
|
|
|
|
|
|
namespace fastrt {
|
|
|
|
|
2021-02-13 21:14:14 +08:00
|
|
|
class embedding_head : public Module {
|
2021-02-13 21:28:29 +08:00
|
|
|
private:
|
2021-02-27 16:40:04 +08:00
|
|
|
FastreidConfig& _modelCfg;
|
2021-02-13 21:28:29 +08:00
|
|
|
std::unique_ptr<LayerFactory> _layerFactory;
|
|
|
|
|
2021-02-13 21:14:14 +08:00
|
|
|
public:
|
2021-02-27 16:40:04 +08:00
|
|
|
embedding_head(FastreidConfig& modelCfg);
|
|
|
|
embedding_head(FastreidConfig& modelCfg, std::unique_ptr<LayerFactory> layerFactory);
|
2021-02-13 21:14:14 +08:00
|
|
|
~embedding_head() = default;
|
2021-01-30 15:40:05 +08:00
|
|
|
|
2021-02-13 21:14:14 +08:00
|
|
|
ILayer* topology(INetworkDefinition *network,
|
|
|
|
std::map<std::string, Weights>& weightMap,
|
2021-02-27 16:40:04 +08:00
|
|
|
ITensor& input) override;
|
2021-02-13 21:14:14 +08:00
|
|
|
};
|
2021-02-13 21:28:29 +08:00
|
|
|
|
2021-01-30 15:40:05 +08:00
|
|
|
}
|