SimCLR/run.py

16 lines
346 B
Python
Raw Normal View History

2020-03-14 09:56:04 +08:00
from simclr import SimCLR
import yaml
2020-03-14 18:01:49 +08:00
from data_aug.dataset_wrapper import DataSetWrapper
2020-03-14 09:56:04 +08:00
def main():
    """Load the run configuration and launch SimCLR training.

    Reads ``config.yaml`` (resolved relative to the current working
    directory), builds a ``DataSetWrapper`` from the ``batch_size`` entry
    and the keyword arguments under ``dataset``, then passes the wrapper
    and the full config mapping to ``SimCLR`` and calls its ``train``
    method.
    """
    # Open the config through a context manager so the file handle is
    # closed as soon as parsing finishes rather than leaking until GC.
    with open("config.yaml", "r") as config_file:
        # NOTE(review): FullLoader is kept for backward compatibility with
        # existing configs; if config.yaml could ever come from an untrusted
        # source, switch to yaml.safe_load.
        config = yaml.load(config_file, Loader=yaml.FullLoader)

    # The 'dataset' section is expanded into keyword arguments, so its keys
    # must match DataSetWrapper's parameter names.
    dataset = DataSetWrapper(config['batch_size'], **config['dataset'])

    simclr = SimCLR(dataset, config)
    simclr.train()
# Script entry point: run training only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()