Commit 7e23b04

feat: PUT when using in-process patching
resolves #35
1 parent 8bf03f6 commit 7e23b04
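
In practice this means a PUT now also copies the file into the fake stage when fakesnow is patched in-process, not only when going through the fakesnow server. A minimal sketch of the new behaviour, assuming fakesnow's documented patch() entry point and that the database/schema given on connect are created automatically; the stage name and temp file are illustrative:

import tempfile

import fakesnow
import snowflake.connector
from snowflake.connector.cursor import DictCursor

with fakesnow.patch():
    conn = snowflake.connector.connect(database="db1", schema="schema1")
    cur = conn.cursor(DictCursor)
    cur.execute("CREATE STAGE stage1")

    with tempfile.NamedTemporaryFile(mode="w+", suffix=".csv") as f:
        f.write("1,2\n")
        f.flush()
        # With this commit the PUT copies the file into the fake stage and
        # returns the upload result rows, instead of only parsing the statement.
        cur.execute(f"PUT 'file://{f.name}' @stage1")
        print(cur.fetchall())
        # e.g. [{'source': ..., 'target': '....gz', 'status': 'UPLOADED', ...}]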

4 files changed: 107 additions, 2 deletions

fakesnow/cursor.py
11 additions, 0 deletions

@@ -10,6 +10,7 @@
 from typing import TYPE_CHECKING, Any, cast

 import duckdb
+import pandas as pd
 import pyarrow  # needed by fetch_arrow_table()
 import snowflake.connector.converter
 import snowflake.connector.errors
@@ -175,6 +176,9 @@ def execute(
             transformed = self._transform(exp, params)
             self._execute(transformed, params)

+            if not kwargs.get("server") and (put_stage_data := transformed.args.get("put_stage_data")):  # pyright: ignore[reportPossiblyUnboundVariable]
+                self._put_files(put_stage_data)
+
             return self
         except snowflake.connector.errors.ProgrammingError as e:
             self._sqlstate = e.sqlstate
@@ -188,6 +192,13 @@ def execute(
             msg = f"{e} not implemented. Please raise an issue via https://github.com/tekumara/fakesnow/issues/new"
             raise snowflake.connector.errors.ProgrammingError(msg=msg, errno=9999, sqlstate="99999") from e

+    def _put_files(self, put_stage_data: stage.UploadCommandDict) -> None:
+        results = stage.upload_files(put_stage_data)
+        _df = pd.DataFrame.from_records(results)
+        self._duck_conn.execute("select * from _df")
+        self._arrow_table = self._duck_conn.fetch_arrow_table()
+        self._rowcount = self._arrow_table.num_rows
+
     def check_db_and_schema(self, expression: exp.Expression) -> None:
         no_database, no_schema = checks.is_unqualified_table_expression(expression)

fakesnow/server.py
1 addition, 1 deletion

@@ -94,7 +94,7 @@ async def query_request(request: Request) -> JSONResponse:

     try:
         # only a single sql statement is sent at a time by the python snowflake connector
-        cur = await run_in_threadpool(conn.cursor().execute, sql_text, binding_params=params)
+        cur = await run_in_threadpool(conn.cursor().execute, sql_text, binding_params=params, server=True)
         rowtype = describe_as_rowtype(cur._describe_last_sql())  # noqa: SLF001

         expr = cur._last_transformed  # noqa: SLF001
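
The server=True flag is presumably what keeps the two modes apart: over HTTP the real Snowflake connector performs the file transfer itself after receiving the PUT response, so the server-side cursor must not also copy the file, whereas in-process there is no connector upload step and the cursor does it. A rough sketch of the two call paths (argument names as in this commit, file path illustrative):

# In-process (patched) path: execute() sees no server flag, so after the PUT
# is transformed it calls _put_files() to copy the file into the fake stage.
cur.execute("PUT 'file:///tmp/data.csv' @stage1")

# Server path (fakesnow/server.py): execute() is called with server=True,
# so _put_files() is skipped and the upload is left to the Snowflake connector.
cur.execute(sql_text, binding_params=params, server=True)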

fakesnow/transforms/stage.py
51 additions, 0 deletions

@@ -1,13 +1,16 @@
 from __future__ import annotations

 import datetime
+import os
 import tempfile
 from pathlib import PurePath
+from typing import Any, TypedDict
 from urllib.parse import urlparse
 from urllib.request import url2pathname

 import snowflake.connector.errors
 import sqlglot
+from snowflake.connector.file_util import SnowflakeFileUtil
 from sqlglot import exp

 from fakesnow.params import MutableParams
@@ -16,6 +19,22 @@
 LOCAL_BUCKET_PATH = tempfile.mkdtemp(prefix="fakesnow_bucket_")


+class StageInfoDict(TypedDict):
+    locationType: str
+    location: str
+    creds: dict[str, Any]
+
+
+class UploadCommandDict(TypedDict):
+    stageInfo: StageInfoDict
+    src_locations: list[str]
+    parallel: int
+    autoCompress: bool
+    sourceCompression: str
+    overwrite: bool
+    command: str
+
+
 def create_stage(
     expression: exp.Expression,
     current_database: str | None,
@@ -240,3 +259,35 @@ def list_stage_files_sql(stage_name: str) -> str:
             strftime(last_modified, '%a, %d %b %Y %H:%M:%S GMT') as last_modified
         from read_blob('{sdir}/*')
     """
+
+
+def upload_files(put_stage_data: UploadCommandDict) -> list[dict[str, Any]]:
+    results = []
+    for src in put_stage_data["src_locations"]:
+        basename = os.path.basename(src)
+        stage_dir = put_stage_data["stageInfo"]["location"]
+
+        os.makedirs(stage_dir, exist_ok=True)
+        gzip_file_name, target_size = SnowflakeFileUtil.compress_file_with_gzip(src, stage_dir)
+
+        # Rename to match expected .gz extension on upload
+        target_basename = basename + ".gz"
+        target = os.path.join(stage_dir, target_basename)
+        os.replace(gzip_file_name, target)
+
+        target_size = os.path.getsize(target)
+        source_size = os.path.getsize(src)
+
+        results.append(
+            {
+                "source": basename,
+                "target": target_basename,
+                "source_size": source_size,
+                "target_size": target_size,
+                "source_compression": "NONE",
+                "target_compression": "GZIP",
+                "status": "UPLOADED",
+                "message": "",
+            }
+        )
+    return results
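
The new upload_files() can also be exercised directly with a hand-built payload matching UploadCommandDict. Only src_locations and stageInfo["location"] are actually read, so in this sketch the other values are placeholders and the paths are illustrative:

import tempfile
from pathlib import Path

from fakesnow.transforms import stage

# Source file and a throwaway directory standing in for the stage location.
src = Path(tempfile.mkdtemp()) / "data.csv"
src.write_text("1,2\n")
stage_dir = tempfile.mkdtemp(prefix="fake_stage_")

results = stage.upload_files(
    {
        # locationType, creds, command etc. are placeholder values here;
        # upload_files only uses src_locations and stageInfo["location"].
        "stageInfo": {"locationType": "LOCAL_FS", "location": stage_dir, "creds": {}},
        "src_locations": [str(src)],
        "parallel": 1,
        "autoCompress": True,
        "sourceCompression": "none_detected",
        "overwrite": False,
        "command": "UPLOAD",
    }
)
print(results)
# e.g. [{'source': 'data.csv', 'target': 'data.csv.gz', 'status': 'UPLOADED', ...}]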

tests/test_stage.py
44 additions, 1 deletion

@@ -1,8 +1,10 @@
+import os
+import tempfile
 from datetime import timezone

 import pytest
 import snowflake.connector.cursor
-from dirty_equals import IsNow
+from dirty_equals import IsDatetime, IsNow


 def test_create_stage(dcur: snowflake.connector.cursor.SnowflakeCursor):
@@ -111,3 +113,44 @@ def test_create_stage_qmark_quoted(_fakesnow: None):
     ):
         dcur.execute("CREATE STAGE identifier(?)", ('"stage1"',))
     assert dcur.fetchall() == [{"status": "Stage area stage1 successfully created."}]
+
+
+def test_put_list(dcur: snowflake.connector.cursor.DictCursor) -> None:
+    with tempfile.NamedTemporaryFile(mode="w+", suffix=".csv") as temp_file:
+        data = "1,2\n"
+        temp_file.write(data)
+        temp_file.flush()
+        temp_file_path = temp_file.name
+        temp_file_basename = os.path.basename(temp_file_path)
+
+        dcur.execute("CREATE STAGE stage4")
+        dcur.execute(f"PUT 'file://{temp_file_path}' @stage4")
+        assert dcur.fetchall() == [
+            {
+                "source": temp_file_basename,
+                "target": f"{temp_file_basename}.gz",
+                "source_size": len(data),
+                "target_size": 42,  # GZIP compressed size
+                "source_compression": "NONE",
+                "target_compression": "GZIP",
+                "status": "UPLOADED",
+                "message": "",
+            }
+        ]
+
+        dcur.execute("LIST @stage4")
+        results = dcur.fetchall()
+        assert len(results) == 1
+        assert results[0] == {
+            "name": f"stage4/{temp_file_basename}.gz",
+            "size": 42,
+            "md5": "29498d110c32a756df8109e70d22fa36",
+            "last_modified": IsDatetime(
+                # string in RFC 7231 date format (e.g. 'Sat, 31 May 2025 08:50:51 GMT')
+                format_string="%a, %d %b %Y %H:%M:%S GMT"
+            ),
+        }
+
+        # fully qualified stage name quoted
+        dcur.execute('CREATE STAGE db1.schema1."stage5"')
+        dcur.execute(f"PUT 'file://{temp_file_path}' @db1.schema1.\"stage5\"")
