/**
 * Copyright (c) 2015-present, Facebook, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the BSD+Patents license found in the
 * LICENSE file in the root directory of this source tree.
 */


#include "OnDiskInvertedLists.h"

#include <pthread.h>

#include <unordered_set>

#include <sys/mman.h>
#include <unistd.h>
#include <sys/types.h>

#include "FaissAssert.h"


namespace faiss {


/**********************************************
 * LockLevels
 **********************************************/


struct LockLevels {
    /* There n times lock1(n), one lock2 and one lock3
     * Invariants:
     *    a single thread can hold one lock1(n) for some n
     *    a single thread can hold lock2, if it holds lock1(n) for some n
     *    a single thread can hold lock3, if it holds lock1(n) for some n
     *       AND lock2 AND no other thread holds lock1(m) for m != n
     */
    pthread_mutex_t mutex1;
    pthread_cond_t level1_cv;
    pthread_cond_t level2_cv;
    pthread_cond_t level3_cv;

    std::unordered_set<int> level1_holders; // which level1 locks are held
    int n_level2; // nb threads that wait on level2
    bool level3_in_use; // a threads waits on level3
    bool level2_in_use;

    LockLevels() {
        pthread_mutex_init(&mutex1, nullptr);
        pthread_cond_init(&level1_cv, nullptr);
        pthread_cond_init(&level2_cv, nullptr);
        pthread_cond_init(&level3_cv, nullptr);
        n_level2 = 0;
        level2_in_use = false;
        level3_in_use = false;
    }

    ~LockLevels() {
        pthread_cond_destroy(&level1_cv);
        pthread_cond_destroy(&level2_cv);
        pthread_cond_destroy(&level3_cv);
        pthread_mutex_destroy(&mutex1);
    }

    void lock_1(int no) {
        pthread_mutex_lock(&mutex1);
        while (level3_in_use || level1_holders.count(no) > 0) {
            pthread_cond_wait(&level1_cv, &mutex1);
        }
        level1_holders.insert(no);
        pthread_mutex_unlock(&mutex1);
    }

    void unlock_1(int no) {
        pthread_mutex_lock(&mutex1);
        assert(level1_holders.count(no) == 1);
        level1_holders.erase(no);
        if (level3_in_use) { // a writer is waiting
            pthread_cond_signal(&level3_cv);
        } else {
            pthread_cond_broadcast(&level1_cv);
        }
        pthread_mutex_unlock(&mutex1);
    }

    void lock_2() {
        pthread_mutex_lock(&mutex1);
        n_level2 ++;
        if (level3_in_use) { // tell waiting level3 that we are blocked
            pthread_cond_signal(&level3_cv);
        }
        while (level2_in_use) {
            pthread_cond_wait(&level2_cv, &mutex1);
        }
        level2_in_use = true;
        pthread_mutex_unlock(&mutex1);
    }

    void unlock_2() {
        pthread_mutex_lock(&mutex1);
        level2_in_use = false;
        n_level2 --;
        pthread_cond_signal(&level2_cv);
        pthread_mutex_unlock(&mutex1);
    }

    void lock_3() {
        pthread_mutex_lock(&mutex1);
        level3_in_use = true;
        // wait until there are no level1 holders anymore except the
        // ones that are waiting on level2 (we are holding lock2)
        while (level1_holders.size() > n_level2) {
            pthread_cond_wait(&level3_cv, &mutex1);
        }
        // don't release the lock!
    }

    void unlock_3() {
        level3_in_use = false;
        // wake up all level1_holders
        pthread_cond_broadcast(&level1_cv);
        pthread_mutex_unlock(&mutex1);
    }

    void print () {
        pthread_mutex_lock(&mutex1);
        printf("State: level3_in_use=%d n_level2=%d level1_holders: [", level3_in_use, n_level2);
        for (int k : level1_holders) {
            printf("%d ", k);
        }
        printf("]\n");
        pthread_mutex_unlock(&mutex1);
    }

};

/**********************************************
 * OngoingPrefetch
 **********************************************/

struct OnDiskInvertedLists::OngoingPrefetch {

    struct Thread {
        pthread_t pth;
        const OnDiskInvertedLists *od;
        int64_t list_no;
    };

    std::vector<Thread> threads;

    pthread_mutex_t mutex;

    // pretext to avoid code below to be optimized out
    static int global_cs;

    const OnDiskInvertedLists *od;

    OngoingPrefetch (const OnDiskInvertedLists *od): od (od)
    {
        pthread_mutex_init (&mutex, nullptr);
    }

    static void* prefetch_list (void * arg) {
        Thread *th = static_cast<Thread*>(arg);

        size_t n = th->od->list_size(th->list_no);
        const Index::idx_t *idx = th->od->get_ids(th->list_no);
        const uint8_t *codes = th->od->get_codes(th->list_no);
        int cs = 0;
        for (size_t i = 0; i < n;i++) {
            cs += idx[i];
        }
        const long *codes8 = (const long*)codes;
        long n8 = n * th->od->code_size / 8;

        for (size_t i = 0; i < n8;i++) {
            cs += codes8[i];
        }
        global_cs += cs & 1;
        return nullptr;
    }

    void prefetch_lists (const long *list_nos, int n) {
        pthread_mutex_lock (&mutex);
        for (auto &th: threads) {
            if (th.list_no != -1) {
                pthread_join (th.pth, nullptr);
            }
        }
        threads.resize (n);
        for (int i = 0; i < n; i++) {
            long list_no = list_nos[i];
            Thread & th = threads[i];
            if (list_no >= 0 && od->list_size(list_no) > 0) {
                th.list_no = list_no;
                th.od = od;
                pthread_create (&th.pth, nullptr, prefetch_list, &th);
            } else {
                th.list_no = -1;
            }
        }
        pthread_mutex_unlock (&mutex);
    }

    ~OngoingPrefetch () {
        pthread_mutex_lock (&mutex);
        for (auto &th: threads) {
            if (th.list_no != -1) {
                pthread_join (th.pth, nullptr);
            }
        }
        pthread_mutex_unlock (&mutex);
        pthread_mutex_destroy (&mutex);
    }

};

int OnDiskInvertedLists::OngoingPrefetch::global_cs = 0;


void OnDiskInvertedLists::prefetch_lists (const long *list_nos, int n) const
{
    pf->prefetch_lists (list_nos, n);
}



/**********************************************
 * OnDiskInvertedLists: mmapping
 **********************************************/


void OnDiskInvertedLists::do_mmap ()
{
    const char *rw_flags = read_only ? "r" : "rw+";
    int prot = read_only ? PROT_READ : PROT_WRITE | PROT_READ;
    FILE *f = fopen (filename.c_str(), rw_flags);
    FAISS_THROW_IF_NOT_FMT (f, "could not open %s in mode %s: %s",
                            filename.c_str(), rw_flags, strerror(errno));

    ptr = (uint8_t*)mmap (nullptr, totsize,
                          prot, MAP_SHARED, fileno (f), 0);

    FAISS_THROW_IF_NOT_FMT (ptr != MAP_FAILED,
                            "could not mmap %s: %s",
                            filename.c_str(),
                            strerror(errno));
    fclose (f);

}

void OnDiskInvertedLists::update_totsize (size_t new_size)
{

    // unmap file
    if (ptr != nullptr) {
        int err = munmap (ptr, totsize);
        FAISS_THROW_IF_NOT_FMT (err == 0, "mumap error: %s",
                                strerror(errno));
    }
    if (totsize == 0) {
        // must create file before truncating it
        FILE *f = fopen (filename.c_str(), "w");
        FAISS_THROW_IF_NOT_FMT (f, "could not open %s in mode W: %s",
                                filename.c_str(), strerror(errno));
        fclose (f);
    }

    if (new_size > totsize) {
        if (!slots.empty() &&
            slots.back().offset + slots.back().capacity == totsize) {
            slots.back().capacity += new_size - totsize;
        } else {
            slots.push_back (Slot(totsize, new_size - totsize));
        }
    } else {
        assert(!"not implemented");
    }

    totsize = new_size;

    // create file
    printf ("resizing %s to %ld bytes\n", filename.c_str(), totsize);

    int err = truncate (filename.c_str(), totsize);

    FAISS_THROW_IF_NOT_FMT (err == 0, "truncate %s to %ld: %s",
                            filename.c_str(), totsize,
                            strerror(errno));
    do_mmap ();
}






/**********************************************
 * OnDiskInvertedLists
 **********************************************/

#define INVALID_OFFSET (size_t)(-1)

OnDiskInvertedLists::List::List ():
    size (0), capacity (0), offset (INVALID_OFFSET)
{}

OnDiskInvertedLists::Slot::Slot (size_t offset, size_t capacity):
    offset (offset), capacity (capacity)
{}

OnDiskInvertedLists::Slot::Slot ():
    offset (0), capacity (0)
{}



OnDiskInvertedLists::OnDiskInvertedLists (
        size_t nlist, size_t code_size,
        const char *filename):
    InvertedLists (nlist, code_size),
    filename (filename),
    totsize (0),
    ptr (nullptr),
    read_only (false),
    locks (new LockLevels ()),
    pf (new OngoingPrefetch (this))
{
    lists.resize (nlist);

    // slots starts empty
}

OnDiskInvertedLists::OnDiskInvertedLists ():
    InvertedLists (0, 0),
    read_only (false),
    locks (new LockLevels ()),
    pf (new OngoingPrefetch (this))
{
}

OnDiskInvertedLists::~OnDiskInvertedLists ()
{
    delete pf;

    // unmap all lists
    if (ptr != nullptr) {
        int err = munmap (ptr, totsize);
        FAISS_THROW_IF_NOT_FMT (err == 0,
                                "mumap error: %s",
                                strerror(errno));
    }
    delete locks;
}




size_t OnDiskInvertedLists::list_size(size_t list_no) const
{
    return lists[list_no].size;
}


const uint8_t * OnDiskInvertedLists::get_codes (size_t list_no) const
{
    return ptr + lists[list_no].offset;
}

const Index::idx_t * OnDiskInvertedLists::get_ids (size_t list_no) const
{
    return (const idx_t*)(ptr + lists[list_no].offset +
                          code_size * lists[list_no].capacity);
}


void OnDiskInvertedLists::update_entries (
      size_t list_no, size_t offset, size_t n_entry,
      const idx_t *ids_in, const uint8_t *codes_in)
{
    FAISS_THROW_IF_NOT (!read_only);
    if (n_entry == 0) return;
    const List & l = lists[list_no];
    assert (n_entry + offset <= l.size);
    idx_t *ids = const_cast<idx_t*>(get_ids (list_no));
    memcpy (ids + offset, ids_in, sizeof(ids_in[0]) * n_entry);
    uint8_t *codes = const_cast<uint8_t*>(get_codes (list_no));
    memcpy (codes + offset * code_size, codes_in, code_size * n_entry);
}

size_t OnDiskInvertedLists::add_entries (
           size_t list_no, size_t n_entry,
           const idx_t* ids, const uint8_t *code)
{
    FAISS_THROW_IF_NOT (!read_only);
    locks->lock_1 (list_no);
    size_t o = list_size (list_no);
    resize_locked (list_no, n_entry + o);
    update_entries (list_no, o, n_entry, ids, code);
    locks->unlock_1 (list_no);
    return o;
}

void OnDiskInvertedLists::resize (size_t list_no, size_t new_size)
{
    FAISS_THROW_IF_NOT (!read_only);
    locks->lock_1 (list_no);
    resize_locked (list_no, new_size);
    locks->unlock_1 (list_no);
}



void OnDiskInvertedLists::resize_locked (size_t list_no, size_t new_size)
{
    List & l = lists[list_no];

    if (new_size <= l.capacity &&
        new_size > l.capacity / 2) {
        l.size = new_size;
        return;
    }

    // otherwise we release the current slot, and find a new one

    locks->lock_2 ();
    free_slot (l.offset, l.capacity);

    List new_l;

    if (new_size == 0) {
        new_l = List();
    } else {
        new_l.size = new_size;
        new_l.capacity = 1;
        while (new_l.capacity < new_size) {
            new_l.capacity *= 2;
        }
        new_l.offset = allocate_slot (
            new_l.capacity * (sizeof(idx_t) + code_size));
    }

    // copy common data
    if (l.offset != new_l.offset) {
        size_t n = std::min (new_size, l.size);
        if (n > 0) {
            memcpy (ptr + new_l.offset, get_codes(list_no), n * code_size);
            memcpy (ptr + new_l.offset + new_l.capacity * code_size,
                    get_ids (list_no), n * sizeof(idx_t));
        }
    }

    lists[list_no] = new_l;
    locks->unlock_2 ();
}

size_t OnDiskInvertedLists::allocate_slot (size_t capacity) {
    // should hold lock2

    auto it = slots.begin();
    while (it != slots.end() && it->capacity < capacity) {
        it++;
    }

    if (it == slots.end()) {
        // not enough capacity
        size_t new_size = totsize == 0 ? 32 : totsize * 2;
        while (new_size - totsize < capacity)
            new_size *= 2;
        locks->lock_3 ();
        update_totsize(new_size);
        locks->unlock_3 ();
        it = slots.begin();
        while (it != slots.end() && it->capacity < capacity) {
            it++;
        }
        assert (it != slots.end());
    }

    size_t o = it->offset;
    if (it->capacity == capacity) {
        slots.erase (it);
    } else {
        // take from beginning of slot
        it->capacity -= capacity;
        it->offset += capacity;
    }

    return o;
}



void OnDiskInvertedLists::free_slot (size_t offset, size_t capacity) {

    // should hold lock2
    if (capacity == 0) return;

    auto it = slots.begin();
    while (it != slots.end() && it->offset <= offset) {
        it++;
    }

    size_t inf = 1UL << 60;

    size_t end_prev = inf;
    if (it != slots.begin()) {
        auto prev = it;
        prev--;
        end_prev = prev->offset + prev->capacity;
    }

    size_t begin_next = 1L << 60;
    if (it != slots.end()) {
        begin_next = it->offset;
    }

    assert (end_prev == inf || offset >= end_prev);
    assert (offset + capacity <= begin_next);

    if (offset == end_prev) {
        auto prev = it;
        prev--;
        if (offset + capacity == begin_next) {
            prev->capacity += capacity + it->capacity;
            slots.erase (it);
        } else {
            prev->capacity += capacity;
        }
    } else {
        if (offset + capacity == begin_next) {
            it->offset -= capacity;
            it->capacity += capacity;
        } else {
            slots.insert (it, Slot (offset, capacity));
        }
    }

    // TODO shrink global storage if needed
}


/*****************************************
 * Compact form
 *****************************************/

size_t OnDiskInvertedLists::merge_from (const InvertedLists **ils, int n_il)
{
    FAISS_THROW_IF_NOT_MSG (totsize == 0, "works only on an empty InvertedLists");

    std::vector<size_t> sizes (nlist);
    for (int i = 0; i < n_il; i++) {
        const InvertedLists *il = ils[i];
        FAISS_THROW_IF_NOT (il->nlist == nlist && il->code_size == code_size);

        for (size_t j = 0; j < nlist; j++)  {
            sizes [j] += il->list_size(j);
        }
    }

    size_t cums = 0;
    size_t ntotal = 0;
    for (size_t j = 0; j < nlist; j++)  {
        ntotal += sizes[j];
        lists[j].size = 0;
        lists[j].capacity = sizes[j];
        lists[j].offset = cums;
        cums += lists[j].capacity * (sizeof(idx_t) + code_size);
    }

    update_totsize (cums);

#pragma omp parallel for
    for (size_t j = 0; j < nlist; j++) {
        List & l = lists[j];
        for (int i = 0; i < n_il; i++) {
            const InvertedLists *il = ils[i];
            size_t n_entry = il->list_size(j);
            l.size += n_entry;
            update_entries (j, l.size - n_entry, n_entry,
                            il->get_ids(j),
                            il->get_codes(j));
        }
        assert (l.size == l.capacity);
    }

    return ntotal;
}






} // namespace faiss