Skip to content

Commit e740ff8

Browse files
authored
[code refactor]decompose RAL interface and custom operators (#1141)
1 parent 4e5f275 commit e740ff8

File tree

114 files changed

+659
-581
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

114 files changed

+659
-581
lines changed

docs/developers/add_custom_call.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ as an example to introduce how to add a custom call operator step by step.
1515

1616
BladeDISC provides a macro `TAO_RAL_API` to register a custom call operator. To
1717
keep the code structure clear, please create a new `transpose_impl.cc` file
18-
under the directory: `tao_compiler/mlir/xla/ral/context/`:
18+
under the directory: `tao_compiler/mlir/ral/context/`:
1919

2020
``` c++
2121
void ral_gpu_transpose_2d(ExecutionContext* ctx, void* stream_handle,
@@ -36,7 +36,7 @@ user-defined.
3636
Please note that the `ral_gpu_transpose_2d` function will be called at runtime, so we
3737
can launch the GPU kernel in the function `LaunchTranspose2DKernel`. Usually you
3838
should add the kernel implementation to the directory
39-
`tao_compiler/mlir/xla/ral/context/custom_library/` .
39+
`tao_compiler/mlir/ral/context/custom_library/` .
4040
4141
### Step2: Translate LMHLO Operator to Custom Call Operator
4242

pytorch_blade/bazel_build.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,8 @@ def pybind11_cflags(self):
6464
def __init__(self, *args, **kwargs):
6565
super().__init__(*args, **kwargs)
6666
self.targets = [
67-
"@org_disc_compiler//mlir/xla/ral:libral_base_context.so",
67+
"@org_disc_compiler//mlir/ral:libral_base_context.so",
68+
"@org_disc_compiler//mlir/custom_ops:libdisc_custom_ops.so",
6869
"//pytorch_blade:libtorch_blade.so",
6970
"//pytorch_blade:_torch_blade.so",
7071
"//tests/mhlo/torch-mlir-opt:torch-mlir-opt",

pytorch_blade/distribution.cfg

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
pytorch_blade/_torch_blade.so
22
pytorch_blade/libtorch_blade.so
3-
external/org_disc_compiler/mlir/xla/ral/libral_base_context.so
3+
external/org_disc_compiler/mlir/ral/libral_base_context.so
4+
external/org_disc_compiler/mlir/custom_ops/libdisc_custom_ops.so
45
external/org_disc_compiler/mlir/disc/disc_compiler_main

pytorch_blade/pytorch_blade/compiler/mlir/runtime/BUILD

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@ cc_library(
1111
deps = [
1212
"//pytorch_blade/common_utils:torch_blade_common",
1313
"//pytorch_blade/compiler/jit:torch_blade_jit",
14-
"@org_disc_compiler//mlir/xla/ral:ral_base_context_lib",
14+
"@org_disc_compiler//mlir/ral:ral_base_context_lib",
15+
"@org_disc_compiler//mlir/custom_ops:disc_custom_ops_lib",
1516
"@local_org_torch//:ATen",
1617
"@local_org_torch//:libtorch",
1718
] + if_cuda_is_configured([

pytorch_blade/pytorch_blade/compiler/mlir/runtime/ral_context.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,8 @@
2828
#endif // TORCH_BLADE_USE_ROCM
2929
#endif // TORCH_BLADE_BUILD_WITH_CUDA
3030

31-
#include "mlir/xla/ral/ral_api.h"
31+
#include "mlir/ral/ral_api.h"
32+
3233
#include "pytorch_blade/common_utils/utils.h"
3334

3435
#ifdef TORCH_BLADE_USE_ROCM

pytorch_blade/pytorch_blade/compiler/mlir/runtime/ral_context.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,11 @@ using CUDAStream = ::c10::hip::HIPStream;
3030
#else // TORCH_BLADE_USE_ROCM
3131
#include <c10/cuda/CUDAStream.h>
3232
#endif // TORCH_BLADE_USE_ROCM
33-
#include "mlir/xla/ral/context/base/cuda/cuda_context_impl.h"
33+
#include "mlir/ral/context/base/cuda/cuda_context_impl.h"
3434
#endif // TORCH_BLADE_BUILD_WITH_CUDA
3535
// TODO(disc): figure out why bazel does not trigger a re-compile of this file
3636
// after we update ral.
37-
#include "mlir/xla/ral/context/base/cpu/cpu_context_impl.h"
37+
#include "mlir/ral/context/base/cpu/cpu_context_impl.h"
3838

3939
#include "pytorch_blade/common_utils/macros.h"
4040
#include "pytorch_blade/common_utils/tempfs.h"

scripts/python/tao_build.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -358,7 +358,7 @@ def test_tao_compiler(root, args):
358358
TARGET_DISC_TRANSFORMS_TEST = "//mlir/disc/transforms/tests/..."
359359
TARGET_DISC_E2E_TEST = "//mlir/disc/tests/..."
360360
TARGET_DISC_RAL_TESTS = [
361-
"//mlir/xla/ral:ral_metadata_test"
361+
"//mlir/ral:ral_metadata_test"
362362
]
363363
TARGET_DISC_PDLL_TESTS = [
364364
"//mlir/disc/tools/disc-pdll/tests/..."

tao/BUILD

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ cc_library(
103103
),
104104
deps = [
105105
":version_header",
106-
"@org_tao_compiler//mlir/xla/ral:ral_bridge",
106+
"@org_tao_compiler//mlir/custom_ops:disc_custom_ops_bridge",
107107
] + if_platform_alibaba([
108108
"@blade_service_common//blade_service_common:blade_service_common_deps",
109109
]) + if_internal_serving(

tao/tao_bridge/executable.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
#include <unordered_map>
1919
#include <vector>
2020

21-
#include "mlir/xla/ral/context/tensorflow/tf_context_impl.h"
21+
#include "mlir/ral/context/tensorflow/tf_context_impl.h"
2222
#include "tao_bridge/common.h"
2323
#include "tao_bridge/tao_compilation_result.pb.h"
2424
#include "tensorflow/core/framework/op_kernel.h"

tao/tao_bridge/mlir/mlir_executable.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
#include "tao_bridge/kernels/disc_launch.h"
2020
#endif // PLATFORM_ALIBABA
2121

22-
#include "mlir/xla/ral/ral_api.h"
22+
#include "mlir/ral/ral_api.h"
2323

2424
namespace tensorflow {
2525
namespace tao {

tao/tao_bridge/ral/CMakeLists.txt

Lines changed: 34 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -76,44 +76,44 @@ endif()
7676

7777
list(APPEND RAL_HDRS
7878
"tensorflow/compiler/mlir/xla/compile_metadata.pb.h"
79-
"tensorflow/compiler/mlir/xla/ral/context/context_util.h"
80-
"tensorflow/compiler/mlir/xla/ral/context/common_context_impl.h"
81-
"tensorflow/compiler/mlir/xla/ral/context/stream_executor_based_impl.h"
82-
"tensorflow/compiler/mlir/xla/ral/context/tensorflow/tf_context_impl.h"
83-
"tensorflow/compiler/mlir/xla/ral/device/cpu/cpu_driver.h"
84-
"tensorflow/compiler/mlir/xla/ral/device/gpu/gpu_driver.h"
85-
"tensorflow/compiler/mlir/xla/ral/ral_api.h"
86-
"tensorflow/compiler/mlir/xla/ral/ral_base.h"
87-
"tensorflow/compiler/mlir/xla/ral/ral_context.h"
88-
"tensorflow/compiler/mlir/xla/ral/ral_driver.h"
89-
"tensorflow/compiler/mlir/xla/ral/ral_helper.h"
90-
"tensorflow/compiler/mlir/xla/ral/ral_logging.h"
79+
"tensorflow/compiler/mlir/ral/context/context_util.h"
80+
"tensorflow/compiler/mlir/ral/context/common_context_impl.h"
81+
"tensorflow/compiler/mlir/ral/context/stream_executor_based_impl.h"
82+
"tensorflow/compiler/mlir/ral/context/tensorflow/tf_context_impl.h"
83+
"tensorflow/compiler/mlir/ral/device/cpu/cpu_driver.h"
84+
"tensorflow/compiler/mlir/ral/device/gpu/gpu_driver.h"
85+
"tensorflow/compiler/mlir/ral/ral_api.h"
86+
"tensorflow/compiler/mlir/ral/ral_base.h"
87+
"tensorflow/compiler/mlir/ral/ral_context.h"
88+
"tensorflow/compiler/mlir/ral/ral_driver.h"
89+
"tensorflow/compiler/mlir/ral/ral_helper.h"
90+
"tensorflow/compiler/mlir/ral/ral_logging.h"
9191
)
9292

9393
list(APPEND RAL_SRCS
9494
"tensorflow/compiler/mlir/xla/compile_metadata.pb.cc"
95-
"tensorflow/compiler/mlir/xla/ral/context/common_context_impl.cc"
96-
"tensorflow/compiler/mlir/xla/ral/context/stream_executor_based_impl.cc"
97-
"tensorflow/compiler/mlir/xla/ral/context/tensorflow/tf_context_impl.cc"
98-
"tensorflow/compiler/mlir/xla/ral/context/tensorflow/tf_kernel_impl.cc"
99-
"tensorflow/compiler/mlir/xla/ral/device/cpu/cpu_driver.cc"
100-
"tensorflow/compiler/mlir/xla/ral/device/gpu/gpu_driver.cc"
101-
"tensorflow/compiler/mlir/xla/ral/ral_api.cc"
102-
"tensorflow/compiler/mlir/xla/ral/ral_context.cc"
103-
"tensorflow/compiler/mlir/xla/ral/ral_helper.cc"
104-
"tensorflow/compiler/mlir/xla/ral/ral_logging.cc"
95+
"tensorflow/compiler/mlir/ral/context/common_context_impl.cc"
96+
"tensorflow/compiler/mlir/ral/context/stream_executor_based_impl.cc"
97+
"tensorflow/compiler/mlir/ral/context/tensorflow/tf_context_impl.cc"
98+
"tensorflow/compiler/mlir/ral/context/tensorflow/tf_kernel_impl.cc"
99+
"tensorflow/compiler/mlir/ral/device/cpu/cpu_driver.cc"
100+
"tensorflow/compiler/mlir/ral/device/gpu/gpu_driver.cc"
101+
"tensorflow/compiler/mlir/ral/ral_api.cc"
102+
"tensorflow/compiler/mlir/ral/ral_context.cc"
103+
"tensorflow/compiler/mlir/ral/ral_helper.cc"
104+
"tensorflow/compiler/mlir/ral/ral_logging.cc"
105105
)
106106

107107
#TODO: revisit this when support DCU in tf bridge
108108
if(${TAO_CUDA})
109109
list(APPEND RAL_SRCS
110-
"tensorflow/compiler/mlir/xla/ral/context/dynamic_sort_impl.cc"
111-
"tensorflow/compiler/mlir/xla/ral/context/random_impl.cc"
112-
"tensorflow/compiler/mlir/xla/ral/context/common_context_impl_cuda.cc"
110+
"tensorflow/compiler/mlir/ral/context/dynamic_sort_impl.cc"
111+
"tensorflow/compiler/mlir/ral/context/random_impl.cc"
112+
"tensorflow/compiler/mlir/ral/context/common_context_impl_cuda.cc"
113113
)
114114
list(APPEND RAL_CU_SRCS
115-
"tensorflow/compiler/mlir/xla/ral/context/custom_library/dynamic_sort.cu.cc"
116-
"tensorflow/compiler/mlir/xla/ral/context/custom_library/random_gpu.cu.cc"
115+
"tensorflow/compiler/mlir/ral/context/custom_library/dynamic_sort.cu.cc"
116+
"tensorflow/compiler/mlir/ral/context/custom_library/random_gpu.cu.cc"
117117
)
118118

119119
foreach(file ${RAL_CU_SRCS})
@@ -127,23 +127,23 @@ endif()
127127

128128
if(${TAO_ROCM} OR ${TAO_DCU})
129129
list(APPEND RAL_SRCS
130-
"tensorflow/compiler/mlir/xla/ral/context/common_context_impl_cuda.cc"
130+
"tensorflow/compiler/mlir/ral/context/common_context_impl_cuda.cc"
131131
)
132132
include_directories("${CMAKE_SOURCE_DIR}")
133133
list(APPEND RAL_HDRS
134-
"tensorflow/compiler/mlir/xla/ral/context/common_context_impl_cuda.h"
134+
"tensorflow/compiler/mlir/ral/context/common_context_impl_cuda.h"
135135
)
136136
endif()
137137

138138
#CustomLib Support for DCU to be revisited later
139139
if(${TAO_ROCM})
140140
list(APPEND RAL_SRCS
141-
"tensorflow/compiler/mlir/xla/ral/context/dynamic_sort_impl.cc"
142-
"tensorflow/compiler/mlir/xla/ral/context/random_impl.cc"
141+
"tensorflow/compiler/mlir/ral/context/dynamic_sort_impl.cc"
142+
"tensorflow/compiler/mlir/ral/context/random_impl.cc"
143143
)
144144
list(APPEND RAL_CU_SRCS
145-
"tensorflow/compiler/mlir/xla/ral/context/custom_library/dynamic_sort.cu.cc"
146-
"tensorflow/compiler/mlir/xla/ral/context/custom_library/random_gpu.cu.cc"
145+
"tensorflow/compiler/mlir/ral/context/custom_library/dynamic_sort.cu.cc"
146+
"tensorflow/compiler/mlir/ral/context/custom_library/random_gpu.cu.cc"
147147
)
148148
foreach(file ${RAL_CU_SRCS})
149149
set_source_files_properties(${file} PROPERTIES LANGUAGE HIP)
@@ -152,7 +152,7 @@ endif()
152152

153153
if(${TAO_ENABLE_MKLDNN})
154154
list(APPEND RAL_SRCS
155-
"tensorflow/compiler/mlir/xla/ral/context/common_context_impl_mkldnn.cc"
155+
"tensorflow/compiler/mlir/ral/context/common_context_impl_mkldnn.cc"
156156
)
157157
endif()
158158

tao_compiler/file_map

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,5 @@
22
decoupling,tensorflow/compiler/decoupling
33
mlir/disc,tensorflow/compiler/mlir/disc
44
mlir/util,mlir/util
5-
mlir/xla/ral,tensorflow/compiler/mlir/xla/ral
5+
mlir/ral,tensorflow/compiler/mlir/ral
66
.bazelrc.user,tensorflow/../.bazelrc.user

0 commit comments

Comments
 (0)