Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 112 additions & 8 deletions scripts/generate_public_headers.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,7 @@ def _write_generated_header(include_root, devices):
default_device_type = _DEVICE_TYPES[default_device]
includes = [
"#include <cstddef>",
"#include <cstdint>",
"#include <type_traits>",
f"#include {_detail_include('data_type.h')}",
f"#include {_detail_include('device.h')}",
Expand Down Expand Up @@ -206,6 +207,8 @@ def _write_generated_header(include_root, devices):

using Stream = typename generated_detail::DefaultErrorRuntime::Stream;

using Event = void*;

using MemcpyKind = std::remove_cv_t<
decltype(generated_detail::DefaultErrorRuntime::kMemcpyHostToHost)>;

Expand Down Expand Up @@ -262,16 +265,32 @@ def params_decl(self):
"Malloc",
(_Param("void**", "ptr"), _Param("std::size_t", "size")),
),
_Function("Error", "Free", (_Param("void*", "ptr"),)),
_Function(
"Error",
"Memset",
"MallocHost",
(_Param("void**", "ptr"), _Param("std::size_t", "size")),
),
_Function(
"Error",
"MallocAsync",
(
_Param("void*", "ptr"),
_Param("int", "value"),
_Param("std::size_t", "count"),
_Param("void**", "ptr"),
_Param("std::size_t", "size"),
_Param("Stream", "stream"),
),
),
_Function("Error", "Free", (_Param("void*", "ptr"),)),
_Function("Error", "FreeHost", (_Param("void*", "ptr"),)),
_Function(
"Error",
"FreeAsync",
(_Param("void*", "ptr"), _Param("Stream", "stream")),
),
_Function(
"Error",
"MemGetInfo",
(_Param("std::size_t*", "free"), _Param("std::size_t*", "total")),
),
_Function(
"Error",
"Memcpy",
Expand All @@ -293,6 +312,60 @@ def params_decl(self):
_Param("Stream", "stream"),
),
),
_Function(
"Error",
"Memset",
(
_Param("void*", "ptr"),
_Param("int", "value"),
_Param("std::size_t", "count"),
),
),
_Function(
"Error",
"MemsetAsync",
(
_Param("void*", "ptr"),
_Param("int", "value"),
_Param("std::size_t", "count"),
_Param("Stream", "stream"),
),
),
_Function("Error", "StreamCreate", (_Param("Stream*", "stream"),)),
_Function("Error", "StreamDestroy", (_Param("Stream", "stream"),)),
_Function("Error", "StreamSynchronize", (_Param("Stream", "stream"),)),
_Function(
"Error",
"StreamWaitEvent",
(
_Param("Stream", "stream"),
_Param("Event", "event"),
_Param("unsigned int", "flags"),
),
),
_Function("Error", "EventCreate", (_Param("Event*", "event"),)),
_Function(
"Error",
"EventCreateWithFlags",
(_Param("Event*", "event"), _Param("unsigned int", "flags")),
),
_Function(
"Error",
"EventRecord",
(_Param("Event", "event"), _Param("Stream", "stream")),
),
_Function("Error", "EventQuery", (_Param("Event", "event"),)),
_Function("Error", "EventSynchronize", (_Param("Event", "event"),)),
_Function("Error", "EventDestroy", (_Param("Event", "event"),)),
_Function(
"Error",
"EventElapsedTime",
(
_Param("float*", "ms"),
_Param("Event", "start"),
_Param("Event", "end"),
),
),
)


Expand All @@ -312,6 +385,16 @@ def _runtime_arg(param, device):
return (
f"reinterpret_cast<typename Runtime<{device_type}>::Stream>({param.name})"
)
if param.type == "Stream*":
return (
f"reinterpret_cast<typename Runtime<{device_type}>::Stream*>({param.name})"
)
if param.type == "Event":
return f"reinterpret_cast<typename Runtime<{device_type}>::Event>({param.name})"
if param.type == "Event*":
return (
f"reinterpret_cast<typename Runtime<{device_type}>::Event*>({param.name})"
)

return param.name

Expand Down Expand Up @@ -351,10 +434,31 @@ def _write_runtime_dispatch_function(function, devices):
"""


def _write_runtime_dispatch(source_path, devices):
def _runtime_header_for_device(source_root, device):
    """Locate the ``runtime_.h`` header registered for ``device``.

    Scans the ``_DEVICE_HEADERS[device]`` entries (3-tuples whose second item
    is the header name and whose third item is the target path) and returns
    ``source_root / target`` for the first entry named ``runtime_.h``.

    Raises:
        ValueError: if no entry for ``device`` is named ``runtime_.h``.
    """
    # A private sentinel distinguishes "no matching entry" from any target value.
    not_found = object()
    runtime_target = next(
        (
            target
            for _, header_name, target in _DEVICE_HEADERS[device]
            if header_name == "runtime_.h"
        ),
        not_found,
    )

    if runtime_target is not_found:
        raise ValueError(f"device {device!r} does not have a runtime header")

    return source_root / runtime_target


def _devices_for_function(function, devices, source_root):
    """Return the devices whose runtime header mentions ``function.name``.

    Each device's runtime header (resolved through
    ``_runtime_header_for_device``) is read as text and searched for
    ``function.name`` as a whole word, so ``Memcpy`` does not match
    ``MemcpyAsync``. The search is purely textual: any whole-word occurrence in
    the header counts as a match.

    Args:
        function: object exposing a ``name`` attribute to look for.
        devices: iterable of device keys; its order is preserved in the result.
        source_root: path object that the header targets are resolved against.

    Returns:
        A tuple of the devices whose runtime header contains the name.
    """
    pattern = re.compile(r"\b" + re.escape(function.name) + r"\b")

    matching_devices = []
    for device in devices:
        header_path = _runtime_header_for_device(source_root, device)
        # Decode explicitly as UTF-8 so the result does not depend on the
        # locale encoding of the machine running the generator.
        header_text = header_path.read_text(encoding="utf-8")
        if pattern.search(header_text):
            matching_devices.append(device)

    return tuple(matching_devices)


def _write_runtime_dispatch(source_path, source_root, devices):
functions = _PUBLIC_RUNTIME_FUNCTIONS
dispatch_functions = "\n".join(
_write_runtime_dispatch_function(function, devices=devices)
_write_runtime_dispatch_function(
function,
devices=_devices_for_function(function, devices, source_root),
)
for function in functions
)
set_device_type_cases = "\n".join(
Expand Down Expand Up @@ -463,7 +567,7 @@ def main():
_write_wrapper(include_root, wrapper_device, header_name, target)

_write_generated_header(include_root, devices)
_write_runtime_dispatch(pathlib.Path(args.source_output), devices)
_write_runtime_dispatch(pathlib.Path(args.source_output), source_root, devices)


if __name__ == "__main__":
Expand Down
124 changes: 122 additions & 2 deletions src/native/cpu/runtime_.h
Comment thread
voltjia marked this conversation as resolved.
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
#ifndef INFINI_RT_CPU_RUNTIME__H_
#define INFINI_RT_CPU_RUNTIME__H_

#include <chrono>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <new>

#include "runtime.h"

Expand All @@ -16,6 +19,8 @@ struct Runtime<Device::Type::kCpu> : RuntimeBase<Runtime<Device::Type::kCpu>> {

using Stream = void*;

using Event = void*;

static constexpr Error kSuccess = 0;

static constexpr Error kErrorInvalidValue = 1;
Expand Down Expand Up @@ -66,12 +71,53 @@ struct Runtime<Device::Type::kCpu> : RuntimeBase<Runtime<Device::Type::kCpu>> {
return kSuccess;
}

// Host allocation entry point: forwards both arguments to `Malloc` and
// returns its status unchanged.
static Error MallocHost(void** ptr, std::size_t size) {
  const Error status = Malloc(ptr, size);
  return status;
}

// Stream-ordered allocation entry point. This implementation never
// allocates: the arguments are ignored and `kErrorInvalidValue` is always
// returned, leaving `*ptr` untouched.
static Error MallocAsync(void** ptr, std::size_t size, Stream) {
  static_cast<void>(ptr);
  static_cast<void>(size);
  return kErrorInvalidValue;
}

// Hands `ptr` straight to `std::free` (no check is performed here) and
// always reports `kSuccess`.
static Error Free(void* ptr) {
  std::free(ptr);
  return kSuccess;
}

static Error FreeHost(void* ptr) { return Free(ptr); }

static Error FreeAsync(void* ptr, Stream) { return kErrorInvalidValue; }

// Reports available (`*free`) and total (`*total`) memory in bytes.
//
// Both outputs are zeroed first. On non-Windows builds the values are then
// read from the `MemAvailable:` and `MemTotal:` entries of `/proc/meminfo`,
// whose numbers are multiplied by 1024 before being stored.
//
// Returns `kErrorInvalidValue` when either pointer is null, when
// `/proc/meminfo` cannot be opened, or when no non-zero `MemTotal:` entry was
// found (which is always the case when `_WIN32` is defined, because nothing
// is queried there). Returns `kSuccess` otherwise; `*free` stays 0 if the
// file has no `MemAvailable:` entry.
static Error MemGetInfo(std::size_t* free, std::size_t* total) {
  if (free == nullptr || total == nullptr) {
    return kErrorInvalidValue;
  }

  *free = 0;
  *total = 0;

#ifndef _WIN32
  FILE* fp = std::fopen("/proc/meminfo", "r");
  if (fp == nullptr) {
    return kErrorInvalidValue;
  }

  // Parse one line at a time. Some entries (e.g. `HugePages_Total:`) have no
  // unit token after the number, so scanning the whole stream with a
  // "label value unit" pattern would swallow the following label and
  // desynchronise the parse.
  char line[256];
  bool found_total = false;
  bool found_available = false;
  while (!(found_total && found_available) &&
         std::fgets(line, sizeof(line), fp) != nullptr) {
    char label[64];
    std::size_t value = 0;
    if (std::sscanf(line, "%63s %zu", label, &value) != 2) {
      continue;  // Not a "label: number" line; skip it.
    }

    if (std::strcmp(label, "MemTotal:") == 0) {
      *total = value * 1024;
      found_total = true;
    } else if (std::strcmp(label, "MemAvailable:") == 0) {
      *free = value * 1024;
      found_available = true;
    }
  }
  std::fclose(fp);
#endif

  return *total == 0 ? kErrorInvalidValue : kSuccess;
}

static Error Memcpy(void* dst, const void* src, std::size_t size, int) {
if ((dst == nullptr || src == nullptr) && size != 0) {
return kErrorInvalidValue;
Expand All @@ -82,6 +128,11 @@ struct Runtime<Device::Type::kCpu> : RuntimeBase<Runtime<Device::Type::kCpu>> {
return kSuccess;
}

// Stream-ordered copy entry point. This implementation copies nothing: all
// arguments are ignored and `kErrorInvalidValue` is always returned.
static Error MemcpyAsync(void* dst, const void* src, std::size_t size,
                         int kind, Stream) {
  static_cast<void>(dst);
  static_cast<void>(src);
  static_cast<void>(size);
  static_cast<void>(kind);
  return kErrorInvalidValue;
}

static Error Memset(void* ptr, int value, std::size_t count) {
if (ptr == nullptr && count != 0) {
return kErrorInvalidValue;
Expand All @@ -92,11 +143,80 @@ struct Runtime<Device::Type::kCpu> : RuntimeBase<Runtime<Device::Type::kCpu>> {
return kSuccess;
}

static Error MemcpyAsync(void* dst, const void* src, std::size_t size,
int kind, Stream) {
static Error MemsetAsync(void* ptr, int value, std::size_t count, Stream) {
return kErrorInvalidValue;
}

// Produces a stream handle. No stream object is allocated: the handle
// written to `*stream` is always null. Returns `kErrorInvalidValue` when
// `stream` itself is null, `kSuccess` otherwise.
static Error StreamCreate(Stream* stream) {
  if (stream != nullptr) {
    *stream = nullptr;
    return kSuccess;
  }

  return kErrorInvalidValue;
}

static Error StreamDestroy(Stream) { return kSuccess; }

static Error StreamSynchronize(Stream) { return kSuccess; }

static Error StreamWaitEvent(Stream, Event, unsigned int) { return kSuccess; }

// Concrete object behind the opaque `Event` handle: `EventCreate` allocates
// one, and `EventRecord`, `EventDestroy` and `EventElapsedTime` cast the
// handle back to this type. It holds a `steady_clock` timestamp.
using CpuEvent = std::chrono::steady_clock::time_point;

// Allocates a `CpuEvent` initialised with the current `steady_clock` time
// and stores it in `*event`.
//
// Returns `kErrorInvalidValue` when `event` is null and
// `kErrorMemoryAllocation` when the non-throwing allocation fails (in which
// case `*event` is set to null); `kSuccess` otherwise.
static Error EventCreate(Event* event) {
  if (event == nullptr) {
    return kErrorInvalidValue;
  }

  CpuEvent* const stamp =
      new (std::nothrow) CpuEvent(std::chrono::steady_clock::now());
  *event = stamp;

  if (stamp == nullptr) {
    return kErrorMemoryAllocation;
  }

  return kSuccess;
}

// Same as `EventCreate`; the flags argument is accepted but ignored.
static Error EventCreateWithFlags(Event* event, unsigned int) {
  const Error status = EventCreate(event);
  return status;
}

// Overwrites the timestamp stored in `event` with the current
// `steady_clock` time. The stream argument is ignored.
//
// Returns `kErrorInvalidValue` for a null handle, `kSuccess` otherwise.
static Error EventRecord(Event event, Stream) {
  auto* const stamp = static_cast<CpuEvent*>(event);
  if (stamp == nullptr) {
    return kErrorInvalidValue;
  }

  *stamp = std::chrono::steady_clock::now();
  return kSuccess;
}

// Reports `kSuccess` for any non-null handle and `kErrorInvalidValue` for a
// null one; the event's contents are not inspected.
static Error EventQuery(Event event) {
  if (event == nullptr) {
    return kErrorInvalidValue;
  }

  return kSuccess;
}

// Performs no waiting: returns `kSuccess` for any non-null handle and
// `kErrorInvalidValue` for a null one.
static Error EventSynchronize(Event event) {
  if (event == nullptr) {
    return kErrorInvalidValue;
  }

  return kSuccess;
}

// Deletes the `CpuEvent` behind the handle and always reports `kSuccess`.
// No null check is made before the `delete`.
static Error EventDestroy(Event event) {
  auto* const stamp = static_cast<CpuEvent*>(event);
  delete stamp;
  return kSuccess;
}

// Writes to `*ms` the time, in milliseconds, from the timestamp stored in
// `start` to the timestamp stored in `end` (negative if `end` is earlier).
//
// Returns `kErrorInvalidValue` when any of the three arguments is null,
// `kSuccess` otherwise.
static Error EventElapsedTime(float* ms, Event start, Event end) {
  if (ms == nullptr || start == nullptr || end == nullptr) {
    return kErrorInvalidValue;
  }

  const auto& start_time = *static_cast<const CpuEvent*>(start);
  const auto& end_time = *static_cast<const CpuEvent*>(end);

  // Convert at the clock's native resolution through a floating-point
  // millisecond duration. Casting to whole microseconds first would round
  // sub-microsecond intervals down to 0.
  const std::chrono::duration<double, std::milli> elapsed =
      end_time - start_time;
  *ms = static_cast<float>(elapsed.count());

  return kSuccess;
}

static constexpr int kMemcpyHostToHost = 0;

static constexpr int kMemcpyHostToDevice = 1;
Expand Down
Loading