Skip to content

Commit 5d29294

Browse files
Add final changes
1 parent 4ea5543 commit 5d29294

File tree

10 files changed

+131
-122
lines changed

10 files changed

+131
-122
lines changed

.fernignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,6 @@ jest.config.js
88
.github/workflows/ci.yml
99
LICENSE
1010
src/BedrockClient.ts
11+
src/AwsClient.ts
12+
src/SagemakerClient.ts
1113
src/index.ts

src/AwsClient.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ import { CohereClient } from "./Client";
33

44
export class AwsClient extends CohereClient {
55
constructor(_options: CohereClient.Options & AwsProps) {
6-
_options.token = "n/a";
6+
_options.token = "n/a"; // AWS clients don't need a token but setting to this to a string so Fern doesn't complain
77
super(_options);
88
}
99
}

src/BedrockClient.ts

Lines changed: 3 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,50 +1,9 @@
1-
import { AwsProps } from 'aws-utils';
1+
import { AwsProps, fetchOverride } from './aws-utils';
22
import { AwsClient } from './AwsClient';
33
import { CohereClient } from "./Client";
4-
import * as serializers from "./serialization";
5-
6-
7-
const withTempEnv = async <R>(updateEnv: () => void, fn: () => Promise<R>): Promise<R> => {
8-
const previousEnv = { ...process.env };
9-
10-
try {
11-
updateEnv();
12-
return await fn();
13-
} finally {
14-
process.env = previousEnv;
15-
}
16-
};
17-
18-
const streamingResponseParser: Record<string, any> = {
19-
"chat": serializers.StreamedChatResponse,
20-
"generate": serializers.GenerateStreamedResponse,
21-
}
22-
23-
const nonStreamedResponseParser: Record<string, any> = {
24-
"chat": serializers.NonStreamedChatResponse,
25-
"embed": serializers.EmbedResponse,
26-
"generate": serializers.Generation,
27-
}
28-
29-
const mapResponseFromBedrock = async (streaming: boolean, endpoint: string, obj: {}) => {
30-
31-
const parser = streaming ? streamingResponseParser[endpoint] : nonStreamedResponseParser[endpoint];
32-
33-
const config = {
34-
unrecognizedObjectKeys: "passthrough",
35-
allowUnrecognizedUnionMembers: true,
36-
allowUnrecognizedEnumValues: true,
37-
skipValidation: true,
38-
breadcrumbsPrefix: ["response"],
39-
}
40-
41-
const parsed = await parser.parseOrThrow(obj, config)
42-
return parser.jsonOrThrow(parsed, config);
43-
}
444

455
export class BedrockClient extends AwsClient {
46-
constructor(protected readonly _options: CohereClient.Options & AwsProps) {
47-
_options.token = "n/a";
48-
super(_options);
6+
constructor(_options: CohereClient.Options & AwsProps) {
7+
super({ ..._options, fetcher: fetchOverride("bedrock", _options) });
498
}
509
}

src/SagemakerClient.ts

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ import { AwsProps, fetchOverride } from './aws-utils';
55

66
export class SagemakerClient extends AwsClient {
77
constructor(_options: CohereClient.Options & AwsProps) {
8-
_options.token = "n/a";
98
super({ ..._options, fetcher: fetchOverride("sagemaker", _options) });
109
}
1110
}

src/aws-utils.ts

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ export const getUrl = (
7575
}[platform];
7676
}
7777

78-
export const getAuthHeaders = async (url: URL, method: string, headers: Record<string, string>, body: unknown, props: AwsProps): Promise<Record<string, string>> => {
78+
export const getAuthHeaders = async (url: URL, method: string, headers: Record<string, string>, body: unknown, service: AwsPlatform, props: AwsProps): Promise<Record<string, string>> => {
7979
const providerChain = fromNodeProviderChain();
8080

8181
const credentials = await withTempEnv(
@@ -101,7 +101,7 @@ export const getAuthHeaders = async (url: URL, method: string, headers: Record<s
101101
);
102102

103103
const signer = new SignatureV4({
104-
service: 'sagemaker',
104+
service,
105105
region: props.awsRegion,
106106
credentials,
107107
sha256: Sha256,
@@ -182,6 +182,7 @@ export const fetchOverride = (platform: AwsPlatform, {
182182
fetcherArgs.method,
183183
fetcherArgs.headers as Record<string, string>,
184184
JSON.stringify(bodyJson),
185+
platform,
185186
{
186187
awsRegion,
187188
awsAccessKey,

src/index.ts

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
export * as Cohere from "./api";
1+
export { BedrockClient } from "./BedrockClient";
22
export { CohereClient } from "./Client";
3+
export { SagemakerClient } from "./SagemakerClient";
4+
export * as Cohere from "./api";
35
export { CohereEnvironment } from "./environments";
46
export { CohereError, CohereTimeoutError } from "./errors";
5-
export { BedrockClient } from "./BedrockClient"
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
// Jest Snapshot v1, https://goo.gl/fbAQLP
2+
3+
exports[`test aws utils parseAWSEvent 1`] = `undefined`;
4+
5+
exports[`test aws utils parseAWSEvent 2`] = `undefined`;
6+
7+
exports[`test aws utils parseAWSEvent 3`] = `
8+
{
9+
"event_type": "text-generation",
10+
"is_finished": false,
11+
"text": "Hello",
12+
}
13+
`;
14+
15+
exports[`test aws utils parseAWSEvent 4`] = `
16+
{
17+
"event_type": "text-generation",
18+
"is_finished": false,
19+
"text": "!",
20+
}
21+
`;

src/test/aws-util-tests.test.ts

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
import { describe, expect, test } from "@jest/globals";
2+
import { parseAWSEvent } from "../aws-utils";
3+
4+
describe("test aws utils", () => {
5+
test.each([
6+
`'�K*��z :event-typechunk'`,
7+
`':content-typeapplication/json'`,
8+
`':message-typeevent{"bytes":"eyJldmVudF90eXBlIjoidGV4dC1nZW5lcmF0aW9uIiwiaXNfZmluaXNoZWQiOmZhbHNlLCJ0ZXh0IjoiSGVsbG8ifQ=="}�B@Q�K�;~t :event-typechunk'`,
9+
`':message-typeevent{"bytes":"eyJldmVudF90eXBlIjoidGV4dC1nZW5lcmF0aW9uIiwiaXNfZmluaXNoZWQiOmZhbHNlLCJ0ZXh0IjoiISJ9"}V�6��K�ش :event-typechunk'`,
10+
])("parseAWSEvent ", (event) => {
11+
expect(parseAWSEvent(event)).toMatchSnapshot();
12+
})
13+
});

src/test/bedrock-tests.test.ts

Lines changed: 80 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -1,97 +1,105 @@
11
import { describe, expect, test } from "@jest/globals";
2+
import { BedrockClient } from "BedrockClient";
3+
import { SagemakerClient } from "SagemakerClient";
24
import { AwsEndpoint, AwsPlatform } from "aws-utils";
35
import { AwsClient } from "../AwsClient";
4-
import { SagemakerClient } from "../SagemakerClient";
5-
66

77
let cohere: AwsClient;
88

9-
109
const config = {
1110
awsRegion: "us-east-1",
12-
}
11+
awsAccessKey: "...",
12+
awsSecretKey: "...",
13+
awsSessionToken: "...",
14+
};
1315

1416
const models: Record<AwsPlatform, Record<AwsEndpoint, string>> = {
15-
"bedrock": {
16-
"generate": "cohere.command-text-v14",
17-
"embed": "cohere.embed-multilingual-v3",
18-
"chat": "cohere.command-r-plus-v1:0"
17+
bedrock: {
18+
generate: "cohere.command-text-v14",
19+
embed: "cohere.embed-multilingual-v3",
20+
chat: "cohere.command-r-plus-v1:0",
1921
},
20-
"sagemaker": {
21-
"generate": "cohere-command-light",
22-
"embed": "xxxx",
23-
"chat": "xxx"
24-
}
25-
}
26-
27-
describe.each<AwsPlatform>(["sagemaker"])("test sdk", (platform) => {
28-
cohere = {
29-
"bedrock": new AwsClient(config),
30-
"sagemaker": new SagemakerClient(config)
31-
}[platform]!;
32-
33-
test.concurrent("generate works", async () => {
34-
const generate = await cohere.generate({
35-
prompt: "Please explain to me how LLMs work",
36-
temperature: 0,
37-
model: models[platform].generate,
38-
});
39-
40-
expect(generate.generations[0].text).toBeDefined();
41-
});
42-
43-
test.concurrent("generate stream works", async () => {
44-
const generate = await cohere.generateStream({
45-
prompt: "Please explain to me how LLMs work",
46-
temperature: 0,
47-
model: models[platform].generate,
22+
sagemaker: {
23+
generate: "cohere-command-light",
24+
embed: "cohere-embed-multilingual-v3",
25+
chat: "cohere-command-plus",
26+
},
27+
};
28+
29+
30+
// skip until we have the right auth in ci
31+
describe.each<AwsPlatform>(["bedrock"])(
32+
"test sdk",
33+
(platform) => {
34+
cohere = {
35+
"bedrock": new BedrockClient(config),
36+
"sagemaker": new SagemakerClient(config)
37+
}[platform]!;
38+
39+
test.skip("generate works", async () => {
40+
const generate = await cohere.generate({
41+
prompt: "Please explain to me how LLMs work",
42+
temperature: 0,
43+
model: models[platform].generate,
44+
});
45+
46+
expect(generate.generations[0].text).toBeDefined();
4847
});
4948

50-
const chunks = [];
49+
test.skip("generate stream works", async () => {
50+
const generate = await cohere.generateStream({
51+
prompt: "Please explain to me how LLMs work",
52+
temperature: 0,
53+
model: models[platform].generate,
54+
});
5155

52-
for await (const chunk of generate) {
53-
chunks.push(chunk);
54-
}
56+
const chunks = [];
5557

56-
expect(chunks[0].eventType).toMatchInlineSnapshot(`"stream-start"`);
57-
expect(chunks[1].eventType).toMatchInlineSnapshot(`"text-generation"`);
58-
expect(chunks[chunks.length - 1].eventType).toMatchInlineSnapshot(`"stream-end"`);
59-
});
58+
for await (const chunk of generate) {
59+
chunks.push(chunk);
60+
}
6061

61-
test.concurrent("embed works", async () => {
62-
const embed = await cohere.embed({
63-
texts: ["hello", "goodbye"],
64-
model: models[platform].embed,
65-
inputType: "search_document",
62+
expect(chunks[0].eventType).toMatchInlineSnapshot(`"stream-start"`);
63+
expect(chunks[1].eventType).toMatchInlineSnapshot(`"text-generation"`);
64+
expect(chunks[chunks.length - 1].eventType).toMatchInlineSnapshot(`"stream-end"`);
6665
});
6766

68-
if (embed.responseType === "embeddings_by_type") {
69-
expect(embed.embeddings?.float?.[0]).toBeDefined();
70-
}
71-
});
67+
test.skip("embed works", async () => {
68+
const embed = await cohere.embed({
69+
texts: ["hello", "goodbye"],
70+
model: models[platform].embed,
71+
inputType: "search_document",
72+
});
7273

73-
test.concurrent("chat works", async () => {
74-
const chat = await cohere.chat({
75-
model: models[platform].chat,
76-
message: "send me a short message",
77-
temperature: 0,
74+
if (embed.responseType === "embeddings_by_type") {
75+
expect(embed.embeddings?.float?.[0]).toBeDefined();
76+
}
7877
});
79-
});
8078

81-
test.concurrent("chat stream works", async () => {
82-
const chat = await cohere.chatStream({
83-
model: models[platform].chat,
84-
message: "send me a short message",
85-
temperature: 0,
79+
test.skip("chat works", async () => {
80+
const chat = await cohere.chat({
81+
model: models[platform].chat,
82+
message: "send me a short message",
83+
temperature: 0,
84+
});
8685
});
8786

88-
const chunks = [];
87+
test.skip("chat stream works", async () => {
88+
const chat = await cohere.chatStream({
89+
model: models[platform].chat,
90+
message: "send me a short message",
91+
temperature: 0,
92+
});
93+
94+
const chunks = [];
8995

90-
for await (const chunk of chat) {
91-
chunks.push(chunk);
92-
}
96+
for await (const chunk of chat) {
97+
chunks.push(chunk);
98+
}
9399

94-
expect(chunks[0].eventType).toMatchInlineSnapshot(`"stream-start"`);
95-
expect(chunks[chunks.length - 1].eventType).toMatchInlineSnapshot(`"stream-end"`);
96-
});
97-
}, 5000);
100+
expect(chunks[0].eventType).toMatchInlineSnapshot(`"text-generation"`);
101+
expect(chunks[chunks.length - 1].eventType).toMatchInlineSnapshot(`"stream-end"`);
102+
});
103+
},
104+
5000
105+
);

yarn.lock

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3582,6 +3582,11 @@ [email protected]:
35823582
resolved "https://registry.npmjs.org/typescript/-/typescript-4.6.4.tgz"
35833583
integrity sha512-9ia/jWHIEbo49HfjrLGfKbZSuWo9iTMwXO+Ca3pRsSpbsMbc7/IU8NKdCZVRRBafVPGnoJeFL76ZOAA84I9fEg==
35843584

3585+
undici-types@~5.26.4:
3586+
version "5.26.5"
3587+
resolved "https://registry.yarnpkg.com/undici-types/-/undici-types-5.26.5.tgz#bcd539893d00b56e964fd2657a4866b221a65617"
3588+
integrity sha512-JlCMO+ehdEIKqlFxk6IfVoAUVmgz7cU7zD/h9XZ0qzeosSHmUJVOzSQvvYSYWXkFXC+IfLKSIffhv0sVZup6pA==
3589+
35853590
universalify@^0.2.0:
35863591
version "0.2.0"
35873592
resolved "https://registry.npmjs.org/universalify/-/universalify-0.2.0.tgz"

0 commit comments

Comments
 (0)