feat: support s3 #16

Open · wants to merge 1 commit into `main`
6 changes: 6 additions & 0 deletions .env.example
@@ -5,3 +5,9 @@ OSS_ACCESS_KEY_ID=
OSS_ACCESS_KEY_SECRET=
ENDPOINT=
BUCKET=

# S3
BUCKET=
AWS_ACCESS_KEY_ID=
AWS_SECRET_ACCESS_KEY=
AWS_DEFAULT_REGION=
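
boto3 picks these variables up from the process environment; for local runs they can be loaded from `.env` with `python-dotenv`, the same approach the integration tests below use. A minimal sketch (the `BUCKET` value comes from your own `.env`):

```python
import os

from dotenv import load_dotenv

from omnistore.objstore import StoreFactory

load_dotenv()  # reads AWS_* and BUCKET from the local .env file

client = StoreFactory.new_client(provider="S3", bucket=os.environ["BUCKET"])
```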
25 changes: 25 additions & 0 deletions README.md
@@ -83,6 +83,31 @@ MINIO_ACCESS_KEY=
MINIO_SECRET_KEY=
```

### [S3](https://aws.amazon.com/s3/)

Usage:

```python
client = StoreFactory.new_client(
    provider="S3", bucket=<bucket>
)

# Use endpoint when accessing S3 via a PrivateLink interface endpoint.
# https://boto3.amazonaws.com/v1/documentation/api/latest/guide/s3-example-privatelink.html
client = StoreFactory.new_client(
    provider="S3", bucket=<bucket>, endpoint=<endpoint>
)
```

Required environment variables:

```yaml
AWS_ACCESS_KEY_ID=
AWS_SECRET_ACCESS_KEY=
# If a region is not specified, the bucket is created in the S3 default region (us-east-1).
AWS_DEFAULT_REGION=
```
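
Once created, the client exposes the same store interface as the other providers. A small round-trip sketch (the bucket and file names here are placeholders, not part of this PR):

```python
from omnistore.objstore import StoreFactory

client = StoreFactory.new_client(provider="S3", bucket="my-bucket")  # "my-bucket" is a placeholder

client.upload("report.txt", "reports/report.txt")
assert client.exists("reports/report.txt")

client.download("reports/report.txt", "report-copy.txt")
client.delete("reports/report.txt")
```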

## Development

To run the integration tests, you need a local `.env` file similar to `.env.example`.
1 change: 1 addition & 0 deletions omnistore/objstore/constant.py
@@ -1,2 +1,3 @@
OBJECT_STORE_OSS = "OSS"
OBJECT_STORE_MINIO = "MINIO"
OBJECT_STORE_S3 = "S3"
6 changes: 4 additions & 2 deletions omnistore/objstore/objstore_factory.py
@@ -1,17 +1,19 @@
  from omnistore.objstore.aliyun_oss import OSS
- from omnistore.objstore.constant import OBJECT_STORE_OSS, OBJECT_STORE_MINIO
+ from omnistore.objstore.constant import OBJECT_STORE_OSS, OBJECT_STORE_MINIO, OBJECT_STORE_S3
  from omnistore.objstore.minio import MinIO
+ from omnistore.objstore.s3 import S3
  from omnistore.store import Store


  class StoreFactory:
      ObjStores = {
          OBJECT_STORE_OSS: OSS,
          OBJECT_STORE_MINIO: MinIO,
+         OBJECT_STORE_S3: S3,
      }

      @classmethod
-     def new_client(cls, provider: str, endpoint: str, bucket: str) -> Store:
+     def new_client(cls, provider: str, endpoint: str = None, bucket: str = None) -> Store:
          objstore = cls.ObjStores[provider]
          if not objstore:
              raise KeyError(f"Unknown object store provider {provider}")
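
With `endpoint` and `bucket` now optional, S3 callers can rely on `AWS_DEFAULT_REGION` and omit the endpoint, while OSS/MinIO callers keep passing one as before. A sketch of both call styles (the bucket name and endpoint URL are placeholders):

```python
from omnistore.objstore import StoreFactory

# S3: region is taken from AWS_DEFAULT_REGION, so no endpoint is needed.
s3_store = StoreFactory.new_client(provider="S3", bucket="my-bucket")

# MinIO/OSS: still pass an explicit endpoint, as with the previous signature.
minio_store = StoreFactory.new_client(
    provider="MINIO", endpoint="http://localhost:9000", bucket="my-bucket"
)
```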
114 changes: 114 additions & 0 deletions omnistore/objstore/s3.py
@@ -0,0 +1,114 @@
import io
import os
from pathlib import Path

import boto3
from botocore.exceptions import ClientError

from omnistore.objstore.objstore import ObjStore


class S3(ObjStore):
    def __init__(self, bucket: str, endpoint: str = None):
        """
        Construct a new client to communicate with the AWS S3 provider.

        AWS credentials are expected to be provided via environment variables:
        - AWS_ACCESS_KEY_ID
        - AWS_SECRET_ACCESS_KEY
        - AWS_DEFAULT_REGION
        """
        region = os.environ.get("AWS_DEFAULT_REGION")

        # If a region is not specified, the bucket is created in the S3 default region (us-east-1).
        # If the user explicitly provides an endpoint_url, the region is not used.
        kwargs = {}
        if endpoint:
            kwargs['endpoint_url'] = endpoint
        if region:
            kwargs['region_name'] = region

        self.client = boto3.client('s3', **kwargs)
        self.resource = boto3.resource('s3', **kwargs)
        self.bucket_name = bucket

        # Make sure the bucket exists
        try:
            self.client.head_bucket(Bucket=bucket)
        except ClientError as e:
            # If the bucket doesn't exist, create it
            if e.response['Error']['Code'] == '404':
                kwargs = {}
                # For regions other than us-east-1 we must pass LocationConstraint when
                # creating the bucket; us-east-1 rejects the parameter.
                if region and region != "us-east-1":
                    kwargs['CreateBucketConfiguration'] = {
                        "LocationConstraint": region
                    }
                self.client.create_bucket(Bucket=bucket, **kwargs)
            else:
                raise e

    def create_dir(self, dirname: str):
        if not dirname.endswith("/"):
            dirname += "/"
        empty_stream = io.BytesIO(b"")
        self.client.put_object(Bucket=self.bucket_name, Key=dirname, Body=empty_stream)

    def delete_dir(self, dirname: str):
        if not dirname.endswith("/"):
            dirname += "/"

        bucket = self.resource.Bucket(self.bucket_name)
        bucket.objects.filter(Prefix=dirname).delete()

    def upload(self, src: str, dest: str):
        self.client.upload_file(src, self.bucket_name, dest)

    def upload_dir(self, src_dir: str, dest_dir: str):
        for file in Path(src_dir).rglob("*"):
            if file.is_file():
                dest_path = f"{dest_dir}/{file.relative_to(src_dir)}"
                self.upload(str(file), dest_path)
            elif file.is_dir():
                self.create_dir(f"{dest_dir}/{file.relative_to(src_dir)}/")

    def download(self, src: str, dest: str):
        self.client.download_file(self.bucket_name, src, dest)

    def download_dir(self, src_dir: str, dest_dir: str):
        if not src_dir.endswith("/"):
            src_dir += "/"
        path = Path(dest_dir)
        if not path.exists():
            path.mkdir(parents=True)

        paginator = self.client.get_paginator('list_objects_v2')
        pages = paginator.paginate(Bucket=self.bucket_name, Prefix=src_dir)

        for page in pages:
            if 'Contents' not in page:
                continue

            for obj in page['Contents']:
                key = obj['Key']
                if key.endswith('/'):  # Skip directories
                    continue

                file_path = Path(dest_dir, Path(key).relative_to(src_dir))
                if not file_path.parent.exists():
                    file_path.parent.mkdir(parents=True, exist_ok=True)

                self.download(key, str(file_path))

    def delete(self, filename: str):
        self.client.delete_object(Bucket=self.bucket_name, Key=filename)

    def exists(self, filename: str):
        try:
            self.client.head_object(Bucket=self.bucket_name, Key=filename)
            return True
        except ClientError as e:
            if e.response['Error']['Code'] == '404':
                return False
            else:
                raise e
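
The optional `endpoint` argument also makes it possible to point this class at any S3-compatible service for local development. A sketch assuming a LocalStack endpoint at `http://localhost:4566` and dummy credentials (the bucket name, file names, and endpoint are illustrative, not part of this PR):

```python
import os
from pathlib import Path

from omnistore.objstore.s3 import S3

# Dummy credentials for a local S3-compatible endpoint (assumed setup).
os.environ.setdefault("AWS_ACCESS_KEY_ID", "test")
os.environ.setdefault("AWS_SECRET_ACCESS_KEY", "test")
os.environ.setdefault("AWS_DEFAULT_REGION", "us-east-1")

store = S3(bucket="omnistore-dev", endpoint="http://localhost:4566")

Path("summary.txt").write_text("hello")
store.create_dir("reports/")
store.upload("summary.txt", "reports/summary.txt")
print(store.exists("reports/summary.txt"))  # True once the object is stored
store.delete_dir("reports/")
```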
65 changes: 65 additions & 0 deletions tests/integration_tests/objstore/test_s3.py
@@ -0,0 +1,65 @@
import os
import shutil

import pytest
from dotenv import load_dotenv

from omnistore.objstore import StoreFactory
from omnistore.objstore.constant import OBJECT_STORE_S3

load_dotenv()


class TestS3:
    @pytest.fixture(scope="module", autouse=True)
    def setup_and_teardown(self):
        print("Setting up the test environment.")
        try:
            os.makedirs("./test-tmp", exist_ok=True)
        except Exception as e:
            print(f"An error occurred: {e}")

        yield

        print("Tearing down the test environment.")
        shutil.rmtree("./test-tmp")

    def test_upload_and_download_files(self):
        bucket = os.getenv("BUCKET")

        client = StoreFactory.new_client(
            provider=OBJECT_STORE_S3, bucket=bucket
        )
        assert not client.exists("foo.txt")

        with open("./test-tmp/foo.txt", "w") as file:
            file.write("test")

        client.upload("./test-tmp/foo.txt", "foo.txt")
        assert client.exists("foo.txt")

        client.download("foo.txt", "./test-tmp/bar.txt")
        assert os.path.exists("./test-tmp/bar.txt")

        client.delete("foo.txt")
        assert not client.exists("foo.txt")

    def test_upload_and_download_dir(self):
        bucket = os.getenv("BUCKET")

        client = StoreFactory.new_client(
            provider=OBJECT_STORE_S3, bucket=bucket
        )
        assert not client.exists("/test/foo.txt")

        os.makedirs("./test-tmp/test/111", exist_ok=True)
        with open("./test-tmp/test/111/foo.txt", "w") as file:
            file.write("test")

        client.upload_dir("./test-tmp/test", "test")
        assert client.exists("test/111/foo.txt")

        client.download_dir("test", "./test-tmp/test1")
        assert os.path.exists("./test-tmp/test1/111/foo.txt")

        client.delete_dir("test")
        assert not client.exists("test/foo.txt")