2020-03-14 09:56:04 +08:00
|
|
|
from simclr import SimCLR
|
|
|
|
import yaml
|
2020-03-14 18:01:49 +08:00
|
|
|
from data_aug.dataset_wrapper import DataSetWrapper
|
2020-03-14 09:56:04 +08:00
|
|
|
|
|
|
|
|
|
|
|
def main(config_path="config.yaml"):
    """Load the training configuration and run SimCLR training.

    Reads a YAML config file, builds a ``DataSetWrapper`` from it, wraps
    that dataset in a ``SimCLR`` trainer, and starts training.

    Args:
        config_path: Path to the YAML configuration file. Defaults to
            ``"config.yaml"`` (resolved against the current working
            directory), which matches the previously hard-coded path, so
            existing ``main()`` calls behave the same.
    """
    # A context manager guarantees the config file handle is closed as soon
    # as parsing finishes, instead of leaking until garbage collection.
    # YAML is a Unicode format, so decode explicitly as UTF-8 rather than
    # relying on the platform's locale encoding.
    with open(config_path, "r", encoding="utf-8") as config_file:
        config = yaml.load(config_file, Loader=yaml.FullLoader)

    # NOTE(review): assumes the config has a 'batch_size' entry and a
    # 'dataset' mapping whose keys match DataSetWrapper's keyword
    # arguments — confirm against config.yaml.
    dataset = DataSetWrapper(config['batch_size'], **config['dataset'])

    # The full config is passed along so the trainer can read its own
    # settings from it.
    simclr = SimCLR(dataset, config)
    simclr.train()
|
|
|
|
|
|
|
|
|
|
|
|
# Script entry point: start training only when this file is executed
# directly (e.g. ``python run.py``), not when it is imported as a module.
if __name__ == "__main__":
    main()
|