/*
 * Copyright (c) 2010-2025 Antmicro
 *
 * This file is licensed under the MIT License.
 */

#include <errno.h>
#include <stdint.h>
#include <string.h>
#include <sys/ioctl.h>

#include "cpu.h"
#include "cpu_registers.h"
#include "registers.h"
#include "utils.h"
#include "unwind.h"
#ifdef TARGET_X86KVM
#include "x86_reports.h"
#endif

/*
 * Return the cached general-purpose register file, refreshing it with
 * KVM_GET_REGS when the cache is CLEAR. The refreshed copy is marked PRESENT
 * only while the vCPU is not executing; during execution every call re-reads
 * from KVM. A failing ioctl is reported through kvm_abortf().
 */
static struct kvm_regs *get_regs()
{
    struct kvm_regs *const cached = &cpu->regs;

    if(cpu->regs_state != CLEAR) {
        return cached;
    }

    if(ioctl(cpu->vcpu_fd, KVM_GET_REGS, cached) < 0) {
        kvm_abortf("KVM_GET_REGS: %s", strerror(errno));
    }
    if(!cpu->is_executing) {
        cpu->regs_state = PRESENT;
    }

    return cached;
}

/*
 * Push `regs` to KVM with KVM_SET_REGS and make it the cached copy.
 * Passing &cpu->regs simply flushes the cache. Afterwards the cache is
 * marked PRESENT. A failing ioctl is reported through kvm_abortf().
 */
static void set_regs(struct kvm_regs *regs)
{
    struct kvm_regs *const cache = &cpu->regs;

    if(regs != cache) {
        *cache = *regs;
    }

    const int rc = ioctl(cpu->vcpu_fd, KVM_SET_REGS, regs);
    if(rc < 0) {
        kvm_abortf("KVM_SET_REGS: %s", strerror(errno));
    }

    cpu->regs_state = PRESENT;
}

/*
 * Return the cached special-register file (segments, control registers, ...),
 * refreshing it with KVM_GET_SREGS when the cache is CLEAR. The refreshed copy
 * is marked PRESENT only while the vCPU is not executing; during execution
 * every call re-reads from KVM. A failing ioctl is reported through kvm_abortf().
 */
static struct kvm_sregs *get_sregs()
{
    struct kvm_sregs *const cached = &cpu->sregs;

    if(cpu->sregs_state != CLEAR) {
        return cached;
    }

    if(ioctl(cpu->vcpu_fd, KVM_GET_SREGS, cached) < 0) {
        kvm_abortf("KVM_GET_SREGS: %s", strerror(errno));
    }
    if(!cpu->is_executing) {
        cpu->sregs_state = PRESENT;
    }

    return cached;
}

/*
 * Push `sregs` to KVM with KVM_SET_SREGS and make it the cached copy.
 * Passing &cpu->sregs simply flushes the cache. Afterwards the cache is
 * marked PRESENT. A failing ioctl is reported through kvm_abortf().
 */
static void set_sregs(struct kvm_sregs *sregs)
{
    struct kvm_sregs *const cache = &cpu->sregs;

    if(sregs != cache) {
        *cache = *sregs;
    }

    const int rc = ioctl(cpu->vcpu_fd, KVM_SET_SREGS, sregs);
    if(rc < 0) {
        kvm_abortf("KVM_SET_SREGS: %s", strerror(errno));
    }

    cpu->sregs_state = PRESENT;
}

/*
 * Write locally modified register caches back to KVM.
 * Only caches in the DIRTY state are pushed. CLEAR and PRESENT caches are
 * left untouched.
 */
void kvm_registers_synchronize()
{
    struct kvm_regs *const regs = &cpu->regs;
    struct kvm_sregs *const sregs = &cpu->sregs;

    if(DIRTY == cpu->regs_state) {
        set_regs(regs);
    }
    if(DIRTY == cpu->sregs_state) {
        set_sregs(sregs);
    }
}

/*
 * Mark both register caches CLEAR so the next access re-reads them from KVM.
 * Any unsynchronized (DIRTY) modifications are discarded.
 */
void kvm_registers_invalidate()
{
    cpu->sregs_state = CLEAR;
    cpu->regs_state = CLEAR;
}

/*
 * Map a general-purpose register number onto its slot inside `regs`.
 *
 * Handles RAX..RDI, RIP and EFLAGS (the latter maps to rflags), plus R8..R15
 * on 64-bit targets. Returns NULL for any other register number.
 */
uint64_t *get_reg_pointer(struct kvm_regs *regs, int reg)
{
    void *field = NULL;

    switch(reg) {
        case RAX:
            field = &regs->rax;
            break;
        case RCX:
            field = &regs->rcx;
            break;
        case RDX:
            field = &regs->rdx;
            break;
        case RBX:
            field = &regs->rbx;
            break;
        case RSP:
            field = &regs->rsp;
            break;
        case RBP:
            field = &regs->rbp;
            break;
        case RSI:
            field = &regs->rsi;
            break;
        case RDI:
            field = &regs->rdi;
            break;
        case RIP:
            field = &regs->rip;
            break;
        case EFLAGS:
            field = &regs->rflags;
            break;
#ifdef TARGET_X86_64KVM
        case R8:
            field = &regs->r8;
            break;
        case R9:
            field = &regs->r9;
            break;
        case R10:
            field = &regs->r10;
            break;
        case R11:
            field = &regs->r11;
            break;
        case R12:
            field = &regs->r12;
            break;
        case R13:
            field = &regs->r13;
            break;
        case R14:
            field = &regs->r14;
            break;
        case R15:
            field = &regs->r15;
            break;
#endif
        default:
            /* Not a general-purpose register: leave `field` NULL. */
            break;
    }

    return (uint64_t *)field;
}

/*
 * Map a special-register number onto its slot inside `sregs`.
 *
 * - CS, SS, DS, ES, FS and GS map to the segment's *base* field.
 * - CR0, CR2, CR3 and CR4 map to the matching crN field.
 * - CR8 and EFER are exposed only on 64-bit targets.
 * - Any other register number returns NULL.
 *
 * CR1 is reserved on x86, and struct kvm_sregs has no field for it. It is
 * therefore backed by a scratch slot that reads as 0 and swallows writes.
 * Aliasing it to cr0 would let a write to CR1 clobber CR0.
 */
uint64_t *get_sreg_pointer(struct kvm_sregs *sregs, int reg)
{
    /* Backing storage for the reserved CR1; reset to 0 on every lookup. */
    static uint64_t cr1_scratch;

    switch(reg) {
        case CS:
            return (uint64_t *)&(sregs->cs.base);
        case SS:
            return (uint64_t *)&(sregs->ss.base);
        case DS:
            return (uint64_t *)&(sregs->ds.base);
        case ES:
            return (uint64_t *)&(sregs->es.base);
        case FS:
            return (uint64_t *)&(sregs->fs.base);
        case GS:
            return (uint64_t *)&(sregs->gs.base);

        case CR0:
            return (uint64_t *)&(sregs->cr0);
        case CR1:
            cr1_scratch = 0;
            return &cr1_scratch;
        case CR2:
            return (uint64_t *)&(sregs->cr2);
        case CR3:
            return (uint64_t *)&(sregs->cr3);
        case CR4:
            return (uint64_t *)&(sregs->cr4);
#ifdef TARGET_X86_64KVM
        case CR8:
            return (uint64_t *)&(sregs->cr8);
        case EFER:
            return (uint64_t *)&(sregs->efer);
#endif

        default:
            return NULL;
    }
}

/*
 * True when the calling thread's id equals cpu->tid, i.e. the thread recorded
 * as the executing one (presumably the thread that runs the vCPU).
 */
static bool is_executing_thread()
{
    return cpu->tid == gettid();
}

/*
 * Register numbers from CS upward are served from struct kvm_sregs; everything
 * below CS is served from struct kvm_regs.
 * NOTE(review): relies on the register enum placing every special register at
 * or after CS.
 */
static bool is_special_register(int reg_number)
{
    return !(reg_number < CS);
}

/* Macro-expand the arguments before invoking `macro`. A macro argument used
 * with # or ## is not expanded, so routing the EXC_* wrapper invocations
 * through this helper guarantees that the kvm_get/set_register_value aliases
 * below are replaced by their ..._64 / ..._32 names first.
 * NOTE(review): presumably needed because EXC_INT_1 / EXC_VOID_2 paste or
 * stringify the function name -- confirm against their definitions. */
#define EXPAND_ARGUMENTS(macro, ...) macro(__VA_ARGS__)
/* The register accessors get a target-width-specific symbol name, while the
 * code in this file keeps using the generic spelling. */
#ifdef TARGET_X86_64KVM
#define kvm_get_register_value kvm_get_register_value_64
#define kvm_set_register_value kvm_set_register_value_64
#else
#define kvm_get_register_value kvm_get_register_value_32
#define kvm_set_register_value kvm_set_register_value_32
#endif

/*
 * Read one CPU register by number.
 *
 * - Numbers classified as special by is_special_register() are served from the
 *   kvm_sregs cache; the rest come from the kvm_regs cache.
 * - Reading while the machine is executing on another thread only logs a warning.
 * - ST0..ST7 are not implemented and read as 0.
 * - Any other unknown number is reported through kvm_runtime_abortf().
 * - On 32-bit targets, a value above UINT32_MAX is passed to
 *   handle_64bit_register_value() before it is returned.
 */
reg_t kvm_get_register_value(int reg_number)
{
    const bool special = is_special_register(reg_number);

    if(cpu->is_executing && !is_executing_thread()) {
        kvm_logf(LOG_LEVEL_WARNING, "Register values are undefined when machine is running");
    }

    uint64_t *slot = special ? get_sreg_pointer(get_sregs(), reg_number)
                             : get_reg_pointer(get_regs(), reg_number);

    if(slot == NULL) {
        const bool is_fpu_stack = reg_number >= ST0 && reg_number <= ST7;
        if(is_fpu_stack) {
            kvm_logf(LOG_LEVEL_INFO, "Reading from STX registers is not implemented");
            return 0;
        }
        kvm_runtime_abortf("Read from undefined CPU register number %d detected", reg_number);
    }

#ifdef TARGET_X86KVM
    if(*slot > UINT32_MAX) {
        handle_64bit_register_value(reg_number, *slot);
    }
#endif

    return *slot;
}
EXPAND_ARGUMENTS(EXC_INT_1, reg_t, kvm_get_register_value, int, reg_number)

/*
 * Entry point taking the Registers enum. It forwards to the target-width
 * implementation that the kvm_get_register_value macro alias resolves to.
 */
reg_t get_register_value(Registers reg_number)
{
    const reg_t value = kvm_get_register_value(reg_number);
    return value;
}

/*
 * Write one CPU register by number.
 *
 * - Refused, with an error log, while the simulation is running.
 * - Otherwise the value is stored in the cached kvm_regs or kvm_sregs copy, and
 *   that cache is marked DIRTY. kvm_registers_synchronize() pushes it to KVM.
 * - Writes to ST0..ST7 are ignored with an info log.
 * - Any other unknown number is reported through kvm_runtime_abortf().
 */
void kvm_set_register_value(int reg_number, reg_t value)
{
    if(cpu->is_executing) {
        kvm_logf(LOG_LEVEL_ERROR, "Cannot set register values when simulation is running");
        return;
    }

    const bool special = is_special_register(reg_number);
    uint64_t *slot = special ? get_sreg_pointer(get_sregs(), reg_number)
                             : get_reg_pointer(get_regs(), reg_number);

    if(slot == NULL) {
        const bool is_fpu_stack = ST0 <= reg_number && reg_number <= ST7;
        if(is_fpu_stack) {
            kvm_logf(LOG_LEVEL_INFO, "Writing to STX registers is not implemented");
            return;
        }
        kvm_runtime_abortf("Write to undefined CPU register number %d detected", reg_number);
    }

    *slot = value;

    /* Remember which cache now differs from the vCPU state held by KVM. */
    if(special) {
        cpu->sregs_state = DIRTY;
    } else {
        cpu->regs_state = DIRTY;
    }
}
EXPAND_ARGUMENTS(EXC_VOID_2, kvm_set_register_value, int, reg_number, reg_t, value)

/*
 * Entry point taking the Registers enum. It forwards to the target-width
 * implementation that the kvm_set_register_value macro alias resolves to.
 */
void set_register_value(Registers reg_number, reg_t value)
{
    kvm_set_register_value((int)reg_number, value);
}

/* Extract the `width`-bit field (1 <= width <= 8) that starts at bit `offset` of `val`. */
#define GET_FIELD(val, offset, width) ((uint8_t)(((val) >> (offset)) & (0xff >> (8 - (width)))))

/*
 * SECTOR_DESCRIPTOR_SETTER(name) defines kvm_set_<name>_descriptor(base, limit, selector, flags)
 * together with its EXC_VOID_4 wrapper. The generated function:
 *   - refuses, with an error log, while the simulation is running;
 *   - overwrites the cached kvm_sregs entry for segment register <name>;
 *   - marks the special-register cache DIRTY so kvm_registers_synchronize() pushes it to KVM.
 *
 * `base`, `limit` and `selector` are stored unchanged. `flags` follows the layout of the
 * high doubleword of a segment descriptor:
 *   - type in bits 8-11, S in bit 12, DPL in bits 13-14, P in bit 15;
 *   - AVL in bit 20, L in bit 21, D/B in bit 22, G in bit 23.
 *
 * The cached entry comes from KVM_GET_SREGS (via get_sregs()), which also fills
 * kvm_segment.unusable. That bit is re-derived from P here, mirroring QEMU's set_seg().
 * This prevents a stale "unusable" from surviving on a segment just set up as present.
 */
#define SECTOR_DESCRIPTOR_SETTER(name)                                                                 \
    void kvm_set_##name##_descriptor(uint64_t base, uint32_t limit, uint16_t selector, uint32_t flags) \
    {                                                                                                  \
        if(cpu->is_executing) {                                                                        \
            kvm_logf(LOG_LEVEL_ERROR, "Cannot set register values when simulation is running");        \
            return;                                                                                    \
        }                                                                                              \
                                                                                                       \
        struct kvm_sregs *sregs = get_sregs();                                                         \
                                                                                                       \
        sregs->name.base = base;                                                                       \
        sregs->name.limit = limit;                                                                     \
        sregs->name.selector = selector;                                                               \
        sregs->name.type = GET_FIELD(flags, 8, 4);                                                     \
        sregs->name.present = GET_FIELD(flags, 15, 1);                                                 \
        sregs->name.dpl = GET_FIELD(flags, 13, 2);                                                     \
        sregs->name.db = GET_FIELD(flags, 22, 1);                                                      \
        sregs->name.s = GET_FIELD(flags, 12, 1);                                                       \
        sregs->name.l = GET_FIELD(flags, 21, 1);                                                       \
        sregs->name.g = GET_FIELD(flags, 23, 1);                                                       \
        sregs->name.avl = GET_FIELD(flags, 20, 1);                                                     \
        /* Do not keep an "unusable" bit cached from KVM_GET_SREGS. */                                 \
        sregs->name.unusable = !sregs->name.present;                                                   \
        sregs->name.padding = 0;                                                                       \
                                                                                                       \
        cpu->sregs_state = DIRTY;                                                                      \
    }                                                                                                  \
                                                                                                       \
    EXC_VOID_4(kvm_set_##name##_descriptor, uint64_t, base, uint32_t, limit, uint16_t, selector, uint32_t, flags)

/* Segment descriptor setters.
 * For more info please refer to the Intel(R) 64 and IA-32 Architectures Software Developer's Manual,
 * Volume 3, sections 3.4.3 (Segment Registers) and 3.4.5 (Segment Descriptors). */
SECTOR_DESCRIPTOR_SETTER(cs)
SECTOR_DESCRIPTOR_SETTER(ds)
SECTOR_DESCRIPTOR_SETTER(es)
SECTOR_DESCRIPTOR_SETTER(ss)
SECTOR_DESCRIPTOR_SETTER(fs)
SECTOR_DESCRIPTOR_SETTER(gs)
