/*
 * Copyright (c) 2023-2024 Ian Marco Moffett and the Osmora Team.
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice,
 *    this list of conditions and the following disclaimer.
 * 2. Redistributions in binary form must reproduce the above copyright
 *    notice, this list of conditions and the following disclaimer in the
 *    documentation and/or other materials provided with the distribution.
 * 3. Neither the name of Hyra nor the names of its
 *    contributors may be used to endorse or promote products derived from
 *    this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
 */

#include <sys/types.h>
#include <sys/param.h>
#include <sys/driver.h>
#include <sys/errno.h>
#include <sys/sched.h>
#include <sys/syslog.h>
#include <sys/mmio.h>
#include <dev/ic/nvmeregs.h>
#include <dev/ic/nvmevar.h>
#include <dev/pci/pci.h>
#include <dev/pci/pciregs.h>
#include <dev/timer.h>
#include <vm/dynalloc.h>
#include <vm/vm.h>
#include <string.h>

/* Logging helpers: every message is passed to kprintf() prefixed with "nvme: " */
#define pr_trace(fmt, ...) kprintf("nvme: " fmt, ##__VA_ARGS__)
#define pr_error(...) pr_trace(__VA_ARGS__)

/* PCI device of the NVMe controller; set by nvme_init() */
static struct pci_device *nvme_dev;
/* General purpose timer used for get_time_usec() polling deadlines and msleep() */
static struct timer tmr;

/*
 * Return 1 if @ptr sits on a 4 KiB boundary, 0 otherwise.
 */
static inline int
is_4k_aligned(void *ptr)
{
    const uintptr_t low_bits = 0xFFF;

    /* Aligned exactly when none of the low 12 address bits are set */
    return !((uintptr_t)ptr & low_bits);
}

/*
 * Poll register to have 'bits' set/unset.
 *
 * @reg: Register to poll.
 * @bits: Bits to be checked.
 * @pollset: True to poll as set.
 */
static int
nvme_poll_reg(struct nvme_bar *bar, volatile uint32_t *reg, uint32_t bits,
              bool pollset)
{
    size_t usec_start, usec;
    size_t elapsed_msec;
    uint32_t val, caps;
    bool tmp;

    usec_start = tmr.get_time_usec();
    caps = mmio_read32(&bar->caps);

    for (;;) {
        val = mmio_read32(reg);
        tmp = (pollset) ? ISSET(val, bits) : !ISSET(val, bits);

        usec = tmr.get_time_usec();
        elapsed_msec = (usec - usec_start) / 1000;

        /* If tmp is set, the register updated in time */
        if (tmp) {
            break;
        }

        /* Exit with an error if we timeout */
        if (elapsed_msec > CAP_TIMEOUT(caps)) {
            return -ETIME;
        }
    }

    return val;
}

/*
 * Allocate and initialize a submission/completion queue pair.
 *
 * @bar: Controller register BAR (CAP and doorbell registers).
 * @queue: Queue to initialize.
 * @id: Queue ID; 0 is the admin queue pair.
 *
 * The depth is CAP.MQES + 1 (MQES is a 0's based value), capped at
 * 4096 entries: the largest queue the 12-bit AQA.ASQS/ACQS fields can
 * describe. Each ring is sized by its real entry type rather than by
 * pointer size, so whole commands/completions fit in every slot.
 *
 * Returns 0 on success or -ENOMEM if either ring cannot be allocated
 * (nothing is left allocated in that case).
 */
static int
nvme_create_queue(struct nvme_bar *bar, struct nvme_queue *queue, size_t id)
{
    const size_t max_entries = 4096;
    /* Doorbell registers start at BAR offset 1000h regardless of page size */
    const uintptr_t db_offset = 0x1000;
    uint8_t dbstride;
    uint64_t caps;
    size_t slots, sq_bytes, cq_bytes, db_size;
    uintptr_t db_base;

    /* CAP is 64 bits wide; DSTRD lives in bits 35:32, beyond a 32-bit read */
    caps = mmio_read64(&bar->caps);
    dbstride = CAP_STRIDE(caps);

    slots = (size_t)CAP_MQES(caps) + 1;
    if (slots > max_entries) {
        slots = max_entries;
    }

    sq_bytes = sizeof(*queue->sq) * slots;
    cq_bytes = sizeof(*queue->cq) * slots;

    queue->sq = dynalloc_memalign(sq_bytes, 0x1000);
    if (queue->sq == NULL) {
        return -ENOMEM;
    }

    queue->cq = dynalloc_memalign(cq_bytes, 0x1000);
    if (queue->cq == NULL) {
        dynfree(queue->sq);
        queue->sq = NULL;
        return -ENOMEM;
    }

    /* Zeroed CQ means every phase bit starts at 0 (nothing posted yet) */
    memset(queue->sq, 0, sq_bytes);
    memset(queue->cq, 0, cq_bytes);

    /* SQ y tail doorbell at 2y, CQ y head doorbell at 2y + 1 (stride 4 << DSTRD) */
    db_size = (size_t)4 << dbstride;
    db_base = (uintptr_t)bar + db_offset;
    queue->sq_db = (void *)(db_base + (2 * id) * db_size);
    queue->cq_db = (void *)(db_base + (2 * id + 1) * db_size);

    queue->sq_head = 0;
    queue->sq_tail = 0;
    queue->cq_head = 0;

    queue->size = slots;
    queue->cq_phase = 1;
    return 0;
}

/*
 * Stop and reset the NVMe controller: clear CC.EN and wait for
 * CSTS.RDY to drop to 0.
 *
 * The early-out only applies when the controller is fully stopped
 * (CC.EN == 0 and CSTS.RDY == 0). With CC.EN set but CSTS.RDY still
 * clear the controller is mid-enable (or wedged) and must still be
 * disabled; otherwise the caller would program the admin queue
 * registers while the controller is enabled.
 *
 * Returns 0 on success or -ETIME if CSTS.RDY does not clear in time.
 */
static int
nvme_stop_ctrl(struct nvme_bar *bar)
{
    uint32_t config, status;

    config = mmio_read32(&bar->config);
    status = mmio_read32(&bar->status);

    /* Already fully stopped, nothing to do */
    if (!ISSET(config, CONFIG_EN) && !ISSET(status, STATUS_RDY)) {
        return 0;
    }

    /* Clear the enable bit to begin the reset */
    config &= ~CONFIG_EN;
    mmio_write32(&bar->config, config);

    if (nvme_poll_reg(bar, &bar->status, STATUS_RDY, false) < 0) {
        pr_error("Controller reset timeout\n");
        return -ETIME;
    }

    return 0;
}

/*
 * Enable the controller by setting CC.EN, then wait for CSTS.RDY.
 *
 * Returns 0 on success (including when CSTS.RDY is already set, in
 * which case nothing is written), or -ETIME if the controller does
 * not report ready in time.
 */
static int
nvme_start_ctrl(struct nvme_bar *bar)
{
    uint32_t csts, cc;

    csts = mmio_read32(&bar->status);
    if (ISSET(csts, STATUS_RDY)) {
        /* Already running */
        return 0;
    }

    cc = mmio_read32(&bar->config);
    mmio_write32(&bar->config, cc | CONFIG_EN);

    if (nvme_poll_reg(bar, &bar->status, STATUS_RDY, true) >= 0) {
        return 0;
    }

    pr_error("Controller startup timeout\n");
    return -ETIME;
}

/*
 * Place @cmd in the submission queue slot at the current tail, advance
 * the tail (wrapping to 0 at q->size) and write the new tail to the
 * SQ tail doorbell.
 */
static void
nvme_submit_cmd(struct nvme_queue *q, struct nvme_cmd cmd)
{
    q->sq[q->sq_tail] = cmd;

    /* Step the tail forward, wrapping at the end of the ring */
    q->sq_tail = (q->sq_tail + 1 >= q->size) ? 0 : q->sq_tail + 1;

    /* Publish the new tail to the controller */
    mmio_write32(q->sq_db, q->sq_tail);
}

/*
 * Submit a command and poll for completion.
 */
static int
nvme_poll_submit_cmd(struct nvme_queue *q, struct nvme_cmd cmd)
{
    uint16_t status;
    uint8_t spins = 0;

    nvme_submit_cmd(q, cmd);

    for (;;) {
        /*
         * If the phase bit matches the most recently submitted
         * command then the command has completed
         */
        status = q->cq[q->cq_head].status;
        if ((status & 1) == q->cq_phase) {
            break;
        }

        /* Are any error bits set? */
        if ((status & ~1) != 0) {
            pr_trace("Command error (bits=0x%x)\n", status >> 1);
            return -EIO;
        }

        /* Check for timeout */
        if (spins > 5) {
            pr_error("Hang while polling phase bit, giving up\n");
            return -ETIME;
        }

        tmr.msleep(150);
        ++spins;
    }

    return 0;
}

static int
nvme_identify(struct nvme_ctrl *ctrl, void *buf, uint32_t nsid, uint8_t cns)
{
    struct nvme_cmd cmd = {0};
    struct nvme_identify_cmd *idcmd = &cmd.identify;

    if (!is_4k_aligned(buf)) {
        return -1;
    }

    idcmd->opcode = NVME_OP_IDENTIFY;
    idcmd->nsid = nsid;
    idcmd->cns = cns;  /* Identify controller */
    idcmd->prp1 = VIRT_TO_PHYS(buf);
    idcmd->prp2 = 0;
    return nvme_poll_submit_cmd(&ctrl->adminq, cmd);
}

/*
 * Copy a fixed-width Identify string (not NUL-terminated) into @dst as
 * a C string. The copy is bounded by @dstlen (which must be >= 1), and
 * trailing space/NUL padding is dropped.
 */
static void
nvme_copy_idstr(char *dst, size_t dstlen, const void *src, size_t srclen)
{
    const char *s = src;
    size_t len = srclen;

    if (len > dstlen - 1) {
        len = dstlen - 1;
    }

    /* NVMe pads these ASCII fields with spaces; strip them for logging */
    while (len > 0 && (s[len - 1] == ' ' || s[len - 1] == '\0')) {
        --len;
    }

    memcpy(dst, s, len);
    dst[len] = '\0';
}

/*
 * For debugging purposes, logs the model number, serial number and
 * firmware revision found within the controller identify data structure.
 */
static void
nvme_log_ctrl_id(struct nvme_id *id)
{
    char mn[41];
    char sn[21];
    char fr[9];

    nvme_copy_idstr(mn, sizeof(mn), id->mn, sizeof(id->mn));
    nvme_copy_idstr(sn, sizeof(sn), id->sn, sizeof(id->sn));
    nvme_copy_idstr(fr, sizeof(fr), id->fr, sizeof(id->fr));

    pr_trace("Model number: %s\n", mn);
    pr_trace("Serial number: %s\n", sn);
    pr_trace("Firmware revision: %s\n", fr);
}

/*
 * Set the bus mastering and memory space bits in the PCI command
 * register of nvme_dev (read-modify-write of PCIREG_CMDSTATUS).
 */
static void
nvme_init_pci(void)
{
    const uint32_t enable_bits = PCI_BUS_MASTERING | PCI_MEM_SPACE;
    uint32_t cmdstatus = pci_readl(nvme_dev, PCIREG_CMDSTATUS);

    pci_writel(nvme_dev, PCIREG_CMDSTATUS, cmdstatus | enable_bits);
}

static int
nvme_init_ctrl(struct nvme_bar *bar)
{
    int error;
    uint64_t caps;
    uint16_t mqes;
    struct nvme_ctrl ctrl = {0};
    struct nvme_queue *adminq;
    struct nvme_id *id;

    /* Ensure the controller is stopped */
    if ((error = nvme_stop_ctrl(bar)) != 0) {
        return error;
    }

    adminq = &ctrl.adminq;
    caps = mmio_read64(&bar->caps);
    mqes = CAP_MQES(caps);

    /* Setup admin queues */
    nvme_create_queue(bar, adminq, 0);
    mmio_write32(&bar->aqa, (mqes | mqes << 16));
    mmio_write64(&bar->asq, VIRT_TO_PHYS(adminq->sq));
    mmio_write64(&bar->acq, VIRT_TO_PHYS(adminq->cq));

    /* Now bring the controller back up */
    if ((error = nvme_start_ctrl(bar)) != 0) {
        return error;
    }

    id = dynalloc_memalign(sizeof(*id), 0x1000);
    if (id == NULL) {
        return -ENOMEM;
    }

    nvme_identify(&ctrl, id, 0, ID_CNS_CTRL);
    nvme_log_ctrl_id(id);

    dynfree(id);
    return 0;
}

static int
nvme_init(void)
{
    struct pci_lookup lookup;
    struct nvme_bar *bar;
    int error;

    lookup.pci_class = 1;
    lookup.pci_subclass = 8;
    nvme_dev = pci_get_device(lookup, PCI_CLASS | PCI_SUBCLASS);

    if (nvme_dev == NULL) {
        return -ENODEV;
    }

    /* Try to request a general purpose timer */
    if (req_timer(TIMER_GP, &tmr) != TMRR_SUCCESS) {
        pr_error("Failed to fetch general purpose timer\n");
        return -ENODEV;
    }

    /* Ensure it has get_time_usec() */
    if (tmr.get_time_usec == NULL) {
        pr_error("General purpose timer has no get_time_usec()\n");
        return -ENODEV;
    }

    /* We also need msleep() */
    if (tmr.msleep == NULL) {
        pr_error("General purpose timer has no msleep()\n");
        return -ENODEV;
    }

    nvme_init_pci();

    if ((error = pci_map_bar(nvme_dev, 0, (void *)&bar)) != 0) {
        return error;
    }

    return nvme_init_ctrl(bar);
}

DRIVER_EXPORT(nvme_init);